pizb commited on
Commit
ece63b7
β€’
1 Parent(s): 418a437

wip: add training pipeline 1

Browse files
article_base_train_test.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import notebook_login
2
+ from datasets import load_dataset
3
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
4
+ import torch
5
+ from peft import get_peft_model, LoraConfig
6
+
7
+
8
+ def main():
9
+ ds = load_dataset('HuggingFaceM4/VQAv2', split="train", trust_remote_code=True)
10
+ cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
11
+ ds = ds.remove_columns(cols_remove)
12
+ ds = ds.train_test_split(test_size=0.1)
13
+ train_ds = ds["train"]
14
+ val_ds = ds["test"]
15
+
16
+ model_id = "google/paligemma-3b-pt-224"
17
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
18
+ image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
19
+ device = "cuda"
20
+
21
+ bnb_config = BitsAndBytesConfig(
22
+ load_in_4bit=True,
23
+ bnb_4bit_quant_type="nf4",
24
+ bnb_4bit_compute_type=torch.bfloat16
25
+ )
26
+ lora_config = LoraConfig(
27
+ r=8,
28
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
29
+ task_type="CAUSAL_LM",
30
+ )
31
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
32
+ model = get_peft_model(model, lora_config)
33
+ model.print_trainable_parameters()
34
+ #trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
35
+
36
+ args=TrainingArguments(
37
+ num_train_epochs=2,
38
+ remove_unused_columns=False,
39
+ per_device_train_batch_size=16,
40
+ gradient_accumulation_steps=4,
41
+ warmup_steps=2,
42
+ learning_rate=2e-5,
43
+ weight_decay=1e-6,
44
+ adam_beta2=0.999,
45
+ logging_steps=100,
46
+ # optim="adamw_hf",
47
+ optim="paged_adamw_8bit", # for QLoRA
48
+ save_strategy="steps",
49
+ save_steps=1000,
50
+ push_to_hub=True,
51
+ save_total_limit=1,
52
+ bf16=True,
53
+ report_to=["tensorboard"],
54
+ dataloader_pin_memory=False
55
+ )
56
+
57
+ def collate_fn(examples):
58
+ texts = ["answer " + example["question"] for example in examples]
59
+ labels= [example['multiple_choice_answer'] for example in examples] # μš°λ¦¬λŠ” label 이 ν•„μš” 없을듯?
60
+ images = [example["image"].convert("RGB") for example in examples]
61
+ tokens = processor(text=texts, images=images, suffix=labels,
62
+ return_tensors="pt", padding="longest")
63
+
64
+ tokens = tokens.to(torch.bfloat16).to(device)
65
+ return tokens
66
+
67
+ trainer = Trainer(
68
+ model=model,
69
+ train_dataset=train_ds,
70
+ eval_dataset=val_ds,
71
+ data_collator=collate_fn,
72
+ args=args
73
+ )
74
+
75
+ trainer.train()
76
+
77
+
78
+ if __name__ == "__main__":
79
+ notebook_login()
80
+ main()
article_base_tutorial.ipynb ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "4d8a398ca84a42d7b745d8e32e6ad3dd",
12
+ "version_major": 2,
13
+ "version_minor": 0
14
+ },
15
+ "text/plain": [
16
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
17
+ ]
18
+ },
19
+ "metadata": {},
20
+ "output_type": "display_data"
21
+ }
22
+ ],
23
+ "source": [
24
+ "from huggingface_hub import notebook_login\n",
25
+ "notebook_login()"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": [
32
+ "# Load Dataset"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 2,
38
+ "metadata": {},
39
+ "outputs": [
40
+ {
41
+ "name": "stderr",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "Repo card metadata block was not found. Setting CardData to empty.\n"
45
+ ]
46
+ },
47
+ {
48
+ "data": {
49
+ "application/vnd.jupyter.widget-view+json": {
50
+ "model_id": "83c256fd38a143b6abded3fbf09d8bd8",
51
+ "version_major": 2,
52
+ "version_minor": 0
53
+ },
54
+ "text/plain": [
55
+ "Downloading data: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
56
+ ]
57
+ },
58
+ "metadata": {},
59
+ "output_type": "display_data"
60
+ },
61
+ {
62
+ "ename": "FSTimeoutError",
63
+ "evalue": "",
64
+ "output_type": "error",
65
+ "traceback": [
66
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
67
+ "\u001b[0;31mTimeoutError\u001b[0m Traceback (most recent call last)",
68
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:56\u001b[0m, in \u001b[0;36m_runner\u001b[0;34m(event, coro, result, timeout)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 56\u001b[0m result[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m coro\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m ex:\n",
69
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/implementations/http.py:262\u001b[0m, in \u001b[0;36mHTTPFileSystem._get_file\u001b[0;34m(self, rpath, lpath, chunk_size, callback, **kwargs)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m chunk:\n\u001b[0;32m--> 262\u001b[0m chunk \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m r\u001b[38;5;241m.\u001b[39mcontent\u001b[38;5;241m.\u001b[39mread(chunk_size)\n\u001b[1;32m 263\u001b[0m outfile\u001b[38;5;241m.\u001b[39mwrite(chunk)\n",
70
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:396\u001b[0m, in \u001b[0;36mStreamReader.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_eof:\n\u001b[0;32m--> 396\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_nowait(n)\n",
71
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:314\u001b[0m, in \u001b[0;36mStreamReader._wait\u001b[0;34m(self, func_name)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 314\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_timer\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mawait\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mwaiter\u001b[49m\n",
72
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/helpers.py:719\u001b[0m, in \u001b[0;36mTimerContext.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exc_type \u001b[38;5;129;01mis\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mCancelledError \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cancelled:\n\u001b[0;32m--> 719\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 720\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
73
+ "\u001b[0;31mTimeoutError\u001b[0m: ",
74
+ "\nThe above exception was the direct cause of the following exception:\n",
75
+ "\u001b[0;31mFSTimeoutError\u001b[0m Traceback (most recent call last)",
76
+ "Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset \n\u001b[0;32m----> 2\u001b[0m ds \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mHuggingFaceM4/VQAv2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \n\u001b[1;32m 3\u001b[0m cols_remove \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswer_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \n\u001b[1;32m 4\u001b[0m ds \u001b[38;5;241m=\u001b[39m ds\u001b[38;5;241m.\u001b[39mremove_columns(cols_remove)\n",
77
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/load.py:2096\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2093\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 2095\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2096\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2097\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2098\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2099\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2100\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2101\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2104\u001b[0m \u001b[38;5;66;03m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 2105\u001b[0m keep_in_memory \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2106\u001b[0m keep_in_memory \u001b[38;5;28;01mif\u001b[39;00m keep_in_memory \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_small_dataset(builder_instance\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size)\n\u001b[1;32m 2107\u001b[0m )\n",
78
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:924\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, dl_manager, base_path, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 922\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_proc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 923\u001b[0m prepare_split_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_proc\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m num_proc\n\u001b[0;32m--> 924\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 925\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 927\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 928\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdownload_and_prepare_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 929\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;66;03m# Sync info\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(split\u001b[38;5;241m.\u001b[39mnum_bytes \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39msplits\u001b[38;5;241m.\u001b[39mvalues())\n",
79
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:1647\u001b[0m, in \u001b[0;36mGeneratorBasedBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_splits_kwargs)\u001b[0m\n\u001b[1;32m 1646\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_download_and_prepare\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager, verification_mode, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mprepare_splits_kwargs):\n\u001b[0;32m-> 1647\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_duplicate_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBASIC_CHECKS\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mALL_CHECKS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1652\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_splits_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1653\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
80
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:977\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 975\u001b[0m split_dict \u001b[38;5;241m=\u001b[39m SplitDict(dataset_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset_name)\n\u001b[1;32m 976\u001b[0m split_generators_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_split_generators_kwargs(prepare_split_kwargs)\n\u001b[0;32m--> 977\u001b[0m split_generators \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_split_generators\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msplit_generators_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# Checksums verification\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verification_mode \u001b[38;5;241m==\u001b[39m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS \u001b[38;5;129;01mand\u001b[39;00m dl_manager\u001b[38;5;241m.\u001b[39mrecord_checksums:\n",
81
+ "File \u001b[0;32m~/.cache/huggingface/modules/datasets_modules/datasets/HuggingFaceM4--VQAv2/e4d008385143be7a6bd81e99483e671d5096942bcb987542217121a5ac2cb420/VQAv2.py:118\u001b[0m, in \u001b[0;36mVQAv2Dataset._split_generators\u001b[0;34m(self, dl_manager)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_split_generators\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# urls = _URLS[self.config.name] # TODO later\u001b[39;00m\n\u001b[0;32m--> 118\u001b[0m data_dir \u001b[38;5;241m=\u001b[39m \u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_extract\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_URLS\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m gen_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 120\u001b[0m split_name: {\n\u001b[1;32m 121\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdir_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_path\u001b[39m\u001b[38;5;124m\"\u001b[39m: Path(data_dir[dir_name][split_name])\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m split_name \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest-dev\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 128\u001b[0m }\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 130\u001b[0m datasets\u001b[38;5;241m.\u001b[39mSplitGenerator(\n\u001b[1;32m 131\u001b[0m name\u001b[38;5;241m=\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mSplit\u001b[38;5;241m.\u001b[39mTRAIN,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 145\u001b[0m ),\n\u001b[1;32m 146\u001b[0m ]\n",
82
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:322\u001b[0m, in \u001b[0;36mDownloadManager.download_and_extract\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdownload_and_extract\u001b[39m(\u001b[38;5;28mself\u001b[39m, url_or_urls):\n\u001b[1;32m 307\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Download and extract given `url_or_urls`.\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m Is roughly equivalent to:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124;03m extracted_path(s): `str`, extracted paths of given URL(s).\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mextract(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m)\u001b[49m)\n",
83
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:159\u001b[0m, in \u001b[0;36mDownloadManager.download\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 157\u001b[0m start_time \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow()\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m stack_multiprocessing_download_progress_bars():\n\u001b[0;32m--> 159\u001b[0m downloaded_path_or_paths \u001b[38;5;241m=\u001b[39m \u001b[43mmap_nested\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[43m \u001b[49m\u001b[43mmap_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDownloading data files\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m duration \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow() \u001b[38;5;241m-\u001b[39m start_time\n\u001b[1;32m 169\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading took \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mduration\u001b[38;5;241m.\u001b[39mtotal_seconds()\u001b[38;5;250m \u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m min\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
84
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:512\u001b[0m, in \u001b[0;36mmap_nested\u001b[0;34m(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, parallel_min_length, batched, batch_size, types, disable_tqdm, desc)\u001b[0m\n\u001b[1;32m 509\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m num_proc \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m%\u001b[39m num_proc \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m), \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 510\u001b[0m iterable \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(iter_batched(iterable, batch_size))\n\u001b[1;32m 511\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 512\u001b[0m \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m hf_tqdm(iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, desc\u001b[38;5;241m=\u001b[39mdesc)\n\u001b[1;32m 514\u001b[0m ]\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[1;32m 516\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [mapped_item \u001b[38;5;28;01mfor\u001b[39;00m mapped_batch \u001b[38;5;129;01min\u001b[39;00m mapped \u001b[38;5;28;01mfor\u001b[39;00m mapped_item \u001b[38;5;129;01min\u001b[39;00m mapped_batch]\n",
85
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:399\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 396\u001b[0m k: _single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mapped\n",
86
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:396\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(pbar_iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, position\u001b[38;5;241m=\u001b[39mrank, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobj\u001b[39m\u001b[38;5;124m\"\u001b[39m, desc\u001b[38;5;241m=\u001b[39mpbar_desc) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m--> 396\u001b[0m k: \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [_single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n",
87
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:371\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, types):\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[0;32m--> 371\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdata_struct\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 373\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function(data_struct)\n",
88
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:216\u001b[0m, in \u001b[0;36mDownloadManager._download_batched\u001b[0;34m(self, url_or_filenames, download_config)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m thread_map(\n\u001b[1;32m 203\u001b[0m download_func,\n\u001b[1;32m 204\u001b[0m url_or_filenames,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 212\u001b[0m tqdm_class\u001b[38;5;241m=\u001b[39mtqdm,\n\u001b[1;32m 213\u001b[0m )\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_single\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m url_or_filename \u001b[38;5;129;01min\u001b[39;00m url_or_filenames\n\u001b[1;32m 218\u001b[0m ]\n",
89
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:225\u001b[0m, in \u001b[0;36mDownloadManager._download_single\u001b[0;34m(self, url_or_filename, download_config)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_relative_path(url_or_filename):\n\u001b[1;32m 223\u001b[0m \u001b[38;5;66;03m# append the relative path to the base_path\u001b[39;00m\n\u001b[1;32m 224\u001b[0m url_or_filename \u001b[38;5;241m=\u001b[39m url_or_path_join(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_base_path, url_or_filename)\n\u001b[0;32m--> 225\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mcached_path\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 226\u001b[0m out \u001b[38;5;241m=\u001b[39m tracked_str(out)\n\u001b[1;32m 227\u001b[0m out\u001b[38;5;241m.\u001b[39mset_origin(url_or_filename)\n",
90
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:205\u001b[0m, in \u001b[0;36mcached_path\u001b[0;34m(url_or_filename, download_config, **download_kwargs)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;66;03m# Download external files\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 205\u001b[0m output_path \u001b[38;5;241m=\u001b[39m \u001b[43mget_from_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 208\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 209\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 210\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_etag\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_etag\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 211\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 212\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_desc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 215\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(url_or_filename):\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# File, and it exists.\u001b[39;00m\n\u001b[1;32m 218\u001b[0m output_path \u001b[38;5;241m=\u001b[39m url_or_filename\n",
91
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:415\u001b[0m, in \u001b[0;36mget_from_cache\u001b[0;34m(url, cache_dir, force_download, user_agent, use_etag, token, storage_options, download_desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 413\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not found in cache or force_download set to True, downloading to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtemp_file\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# GET file object\u001b[39;00m\n\u001b[0;32m--> 415\u001b[0m \u001b[43mfsspec_get\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstoring \u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in cache at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcache_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 418\u001b[0m shutil\u001b[38;5;241m.\u001b[39mmove(temp_file\u001b[38;5;241m.\u001b[39mname, cache_path)\n",
92
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:334\u001b[0m, in \u001b[0;36mfsspec_get\u001b[0;34m(url, temp_file, storage_options, desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 321\u001b[0m fs, path \u001b[38;5;241m=\u001b[39m url_to_fs(url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m(storage_options \u001b[38;5;129;01mor\u001b[39;00m {}))\n\u001b[1;32m 322\u001b[0m callback \u001b[38;5;241m=\u001b[39m TqdmCallback(\n\u001b[1;32m 323\u001b[0m tqdm_kwargs\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 324\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdesc\u001b[39m\u001b[38;5;124m\"\u001b[39m: desc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 332\u001b[0m }\n\u001b[1;32m 333\u001b[0m )\n\u001b[0;32m--> 334\u001b[0m \u001b[43mfs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m\n",
93
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:118\u001b[0m, in \u001b[0;36msync_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m obj \u001b[38;5;129;01mor\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msync\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
94
+ "File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:101\u001b[0m, in \u001b[0;36msync\u001b[0;34m(loop, func, timeout, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m return_result \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, asyncio\u001b[38;5;241m.\u001b[39mTimeoutError):\n\u001b[1;32m 100\u001b[0m \u001b[38;5;66;03m# suppress asyncio.TimeoutError, raise FSTimeoutError\u001b[39;00m\n\u001b[0;32m--> 101\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m FSTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mreturn_result\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, \u001b[38;5;167;01mBaseException\u001b[39;00m):\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m return_result\n",
95
+ "\u001b[0;31mFSTimeoutError\u001b[0m: "
96
+ ]
97
+ }
98
+ ],
99
+ "source": [
100
+ "from datasets import load_dataset \n",
101
+ "ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train\", trust_remote_code=True) \n",
102
+ "cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"] \n",
103
+ "ds = ds.remove_columns(cols_remove)\n",
104
+ "ds = ds.train_test_split(test_size=0.1)\n",
105
+ "train_ds = ds[\"train\"]\n",
106
+ "val_ds = ds[\"test\"]"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "# Train (QLoRA 4-bit)"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "from transformers import PaliGemmaProcessor \n",
123
+ "model_id = \"google/paligemma-3b-pt-224\"\n",
124
+ "processor = PaliGemmaProcessor.from_pretrained(model_id)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "import torch\n",
134
+ "device = \"cuda\"\n",
135
+ "\n",
136
+ "image_token = processor.tokenizer.convert_tokens_to_ids(\"<image>\")\n",
137
+ "def collate_fn(examples):\n",
138
+ " texts = [\"answer \" + example[\"question\"] for example in examples]\n",
139
+ " labels= [example['multiple_choice_answer'] for example in examples] # μš°λ¦¬λŠ” label 이 ν•„μš” 없을듯?\n",
140
+ " images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
141
+ " tokens = processor(text=texts, images=images, suffix=labels,\n",
142
+ " return_tensors=\"pt\", padding=\"longest\")\n",
143
+ "\n",
144
+ " tokens = tokens.to(torch.bfloat16).to(device)\n",
145
+ " return tokens"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "from transformers import PaliGemmaForConditionalGeneration\n",
155
+ "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
156
+ "\n",
157
+ "# Freeze the vision tower and the transformer encoder (image encoder)\n",
158
+ "for param in model.vision_tower.parameters():\n",
159
+ " param.requires_grad = False\n",
160
+ "\n",
161
+ "# Projector is not frozen (Article μ—μ„œλŠ” freeze ν•œλ‹€κ³  λ˜μ–΄μžˆμŒ)\n",
162
+ "for param in model.multi_modal_projector.parameters():\n",
163
+ " param.requires_grad = True\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "For QLoRa in 4-bit"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "from transformers import BitsAndBytesConfig\n",
180
+ "from peft import get_peft_model, LoraConfig\n",
181
+ "\n",
182
+ "bnb_config = BitsAndBytesConfig(\n",
183
+ " load_in_4bit=True,\n",
184
+ " bnb_4bit_quant_type=\"nf4\",\n",
185
+ " bnb_4bit_compute_type=torch.bfloat16\n",
186
+ ")\n",
187
+ "\n",
188
+ "lora_config = LoraConfig(\n",
189
+ " r=8, \n",
190
+ " target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
191
+ " task_type=\"CAUSAL_LM\",\n",
192
+ ")\n",
193
+ "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
194
+ "model = get_peft_model(model, lora_config)\n",
195
+ "model.print_trainable_parameters()\n",
196
+ "#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "from transformers import TrainingArguments\n",
206
+ "args=TrainingArguments(\n",
207
+ " num_train_epochs=2,\n",
208
+ " remove_unused_columns=False,\n",
209
+ " per_device_train_batch_size=16,\n",
210
+ " gradient_accumulation_steps=4,\n",
211
+ " warmup_steps=2,\n",
212
+ " learning_rate=2e-5,\n",
213
+ " weight_decay=1e-6,\n",
214
+ " adam_beta2=0.999,\n",
215
+ " logging_steps=100,\n",
216
+ " # optim=\"adamw_hf\",\n",
217
+ " optim=\"paged_adamw_8bit\", # for QLoRA\n",
218
+ " save_strategy=\"steps\",\n",
219
+ " save_steps=1000,\n",
220
+ " push_to_hub=True,\n",
221
+ " save_total_limit=1,\n",
222
+ " bf16=True,\n",
223
+ " report_to=[\"tensorboard\"],\n",
224
+ " dataloader_pin_memory=False\n",
225
+ " )\n"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "metadata": {},
232
+ "outputs": [],
233
+ "source": [
234
+ "from transformers import Trainer\n",
235
+ "trainer = Trainer(\n",
236
+ " model=model,\n",
237
+ " train_dataset=train_ds,\n",
238
+ " eval_dataset=val_ds,\n",
239
+ " data_collator=collate_fn,\n",
240
+ " args=args\n",
241
+ " )"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "trainer.train()"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "markdown",
255
+ "metadata": {},
256
+ "source": [
257
+ "# Inference for test"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": []
266
+ }
267
+ ],
268
+ "metadata": {
269
+ "kernelspec": {
270
+ "display_name": "Python 3",
271
+ "language": "python",
272
+ "name": "python3"
273
+ },
274
+ "language_info": {
275
+ "codemirror_mode": {
276
+ "name": "ipython",
277
+ "version": 3
278
+ },
279
+ "file_extension": ".py",
280
+ "mimetype": "text/x-python",
281
+ "name": "python",
282
+ "nbconvert_exporter": "python",
283
+ "pygments_lexer": "ipython3",
284
+ "version": "3.12.2"
285
+ }
286
+ },
287
+ "nbformat": 4,
288
+ "nbformat_minor": 2
289
+ }