wip: add training pipeline 1
Browse files- article_base_train_test.py +80 -0
- article_base_tutorial.ipynb +289 -0
article_base_train_test.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import notebook_login
|
2 |
+
from datasets import load_dataset
|
3 |
+
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
|
4 |
+
import torch
|
5 |
+
from peft import get_peft_model, LoraConfig
|
6 |
+
|
7 |
+
|
8 |
+
def main():
|
9 |
+
ds = load_dataset('HuggingFaceM4/VQAv2', split="train", trust_remote_code=True)
|
10 |
+
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
|
11 |
+
ds = ds.remove_columns(cols_remove)
|
12 |
+
ds = ds.train_test_split(test_size=0.1)
|
13 |
+
train_ds = ds["train"]
|
14 |
+
val_ds = ds["test"]
|
15 |
+
|
16 |
+
model_id = "google/paligemma-3b-pt-224"
|
17 |
+
processor = PaliGemmaProcessor.from_pretrained(model_id)
|
18 |
+
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
|
19 |
+
device = "cuda"
|
20 |
+
|
21 |
+
bnb_config = BitsAndBytesConfig(
|
22 |
+
load_in_4bit=True,
|
23 |
+
bnb_4bit_quant_type="nf4",
|
24 |
+
bnb_4bit_compute_type=torch.bfloat16
|
25 |
+
)
|
26 |
+
lora_config = LoraConfig(
|
27 |
+
r=8,
|
28 |
+
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
|
29 |
+
task_type="CAUSAL_LM",
|
30 |
+
)
|
31 |
+
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
|
32 |
+
model = get_peft_model(model, lora_config)
|
33 |
+
model.print_trainable_parameters()
|
34 |
+
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
|
35 |
+
|
36 |
+
args=TrainingArguments(
|
37 |
+
num_train_epochs=2,
|
38 |
+
remove_unused_columns=False,
|
39 |
+
per_device_train_batch_size=16,
|
40 |
+
gradient_accumulation_steps=4,
|
41 |
+
warmup_steps=2,
|
42 |
+
learning_rate=2e-5,
|
43 |
+
weight_decay=1e-6,
|
44 |
+
adam_beta2=0.999,
|
45 |
+
logging_steps=100,
|
46 |
+
# optim="adamw_hf",
|
47 |
+
optim="paged_adamw_8bit", # for QLoRA
|
48 |
+
save_strategy="steps",
|
49 |
+
save_steps=1000,
|
50 |
+
push_to_hub=True,
|
51 |
+
save_total_limit=1,
|
52 |
+
bf16=True,
|
53 |
+
report_to=["tensorboard"],
|
54 |
+
dataloader_pin_memory=False
|
55 |
+
)
|
56 |
+
|
57 |
+
def collate_fn(examples):
|
58 |
+
texts = ["answer " + example["question"] for example in examples]
|
59 |
+
labels= [example['multiple_choice_answer'] for example in examples] # μ°λ¦¬λ label μ΄ νμ μμλ―?
|
60 |
+
images = [example["image"].convert("RGB") for example in examples]
|
61 |
+
tokens = processor(text=texts, images=images, suffix=labels,
|
62 |
+
return_tensors="pt", padding="longest")
|
63 |
+
|
64 |
+
tokens = tokens.to(torch.bfloat16).to(device)
|
65 |
+
return tokens
|
66 |
+
|
67 |
+
trainer = Trainer(
|
68 |
+
model=model,
|
69 |
+
train_dataset=train_ds,
|
70 |
+
eval_dataset=val_ds,
|
71 |
+
data_collator=collate_fn,
|
72 |
+
args=args
|
73 |
+
)
|
74 |
+
|
75 |
+
trainer.train()
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == "__main__":
|
79 |
+
notebook_login()
|
80 |
+
main()
|
article_base_tutorial.ipynb
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"data": {
|
10 |
+
"application/vnd.jupyter.widget-view+json": {
|
11 |
+
"model_id": "4d8a398ca84a42d7b745d8e32e6ad3dd",
|
12 |
+
"version_major": 2,
|
13 |
+
"version_minor": 0
|
14 |
+
},
|
15 |
+
"text/plain": [
|
16 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svβ¦"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"metadata": {},
|
20 |
+
"output_type": "display_data"
|
21 |
+
}
|
22 |
+
],
|
23 |
+
"source": [
|
24 |
+
"from huggingface_hub import notebook_login\n",
|
25 |
+
"notebook_login()"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "markdown",
|
30 |
+
"metadata": {},
|
31 |
+
"source": [
|
32 |
+
"# Load Dataset"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": 2,
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [
|
40 |
+
{
|
41 |
+
"name": "stderr",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"Repo card metadata block was not found. Setting CardData to empty.\n"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"data": {
|
49 |
+
"application/vnd.jupyter.widget-view+json": {
|
50 |
+
"model_id": "83c256fd38a143b6abded3fbf09d8bd8",
|
51 |
+
"version_major": 2,
|
52 |
+
"version_minor": 0
|
53 |
+
},
|
54 |
+
"text/plain": [
|
55 |
+
"Downloading data: 0%| | 0.00/13.5G [00:00<?, ?B/s]"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
"metadata": {},
|
59 |
+
"output_type": "display_data"
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"ename": "FSTimeoutError",
|
63 |
+
"evalue": "",
|
64 |
+
"output_type": "error",
|
65 |
+
"traceback": [
|
66 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
67 |
+
"\u001b[0;31mTimeoutError\u001b[0m Traceback (most recent call last)",
|
68 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:56\u001b[0m, in \u001b[0;36m_runner\u001b[0;34m(event, coro, result, timeout)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 56\u001b[0m result[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m coro\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m ex:\n",
|
69 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/implementations/http.py:262\u001b[0m, in \u001b[0;36mHTTPFileSystem._get_file\u001b[0;34m(self, rpath, lpath, chunk_size, callback, **kwargs)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m chunk:\n\u001b[0;32m--> 262\u001b[0m chunk \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m r\u001b[38;5;241m.\u001b[39mcontent\u001b[38;5;241m.\u001b[39mread(chunk_size)\n\u001b[1;32m 263\u001b[0m outfile\u001b[38;5;241m.\u001b[39mwrite(chunk)\n",
|
70 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:396\u001b[0m, in \u001b[0;36mStreamReader.read\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_buffer \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_eof:\n\u001b[0;32m--> 396\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wait(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mread\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_nowait(n)\n",
|
71 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/streams.py:314\u001b[0m, in \u001b[0;36mStreamReader._wait\u001b[0;34m(self, func_name)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 314\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_timer\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mawait\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mwaiter\u001b[49m\n",
|
72 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/aiohttp/helpers.py:719\u001b[0m, in \u001b[0;36mTimerContext.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exc_type \u001b[38;5;129;01mis\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mCancelledError \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cancelled:\n\u001b[0;32m--> 719\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 720\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
73 |
+
"\u001b[0;31mTimeoutError\u001b[0m: ",
|
74 |
+
"\nThe above exception was the direct cause of the following exception:\n",
|
75 |
+
"\u001b[0;31mFSTimeoutError\u001b[0m Traceback (most recent call last)",
|
76 |
+
"Cell \u001b[0;32mIn[2], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdatasets\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_dataset \n\u001b[0;32m----> 2\u001b[0m ds \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mHuggingFaceM4/VQAv2\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \n\u001b[1;32m 3\u001b[0m cols_remove \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswers\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124manswer_type\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimage_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquestion_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \n\u001b[1;32m 4\u001b[0m ds \u001b[38;5;241m=\u001b[39m ds\u001b[38;5;241m.\u001b[39mremove_columns(cols_remove)\n",
|
77 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/load.py:2096\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2093\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 2095\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2096\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2097\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2098\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2099\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2100\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2101\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2104\u001b[0m \u001b[38;5;66;03m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 2105\u001b[0m keep_in_memory \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2106\u001b[0m keep_in_memory \u001b[38;5;28;01mif\u001b[39;00m keep_in_memory \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_small_dataset(builder_instance\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size)\n\u001b[1;32m 2107\u001b[0m )\n",
|
78 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:924\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, dl_manager, base_path, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 922\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_proc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 923\u001b[0m prepare_split_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_proc\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m num_proc\n\u001b[0;32m--> 924\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 925\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 927\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 928\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdownload_and_prepare_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 929\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;66;03m# Sync info\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(split\u001b[38;5;241m.\u001b[39mnum_bytes \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39msplits\u001b[38;5;241m.\u001b[39mvalues())\n",
|
79 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:1647\u001b[0m, in \u001b[0;36mGeneratorBasedBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_splits_kwargs)\u001b[0m\n\u001b[1;32m 1646\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_download_and_prepare\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager, verification_mode, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mprepare_splits_kwargs):\n\u001b[0;32m-> 1647\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1648\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1649\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1650\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_duplicate_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBASIC_CHECKS\u001b[49m\n\u001b[1;32m 1651\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mVerificationMode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mALL_CHECKS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1652\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_splits_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1653\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
80 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/builder.py:977\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 975\u001b[0m split_dict \u001b[38;5;241m=\u001b[39m SplitDict(dataset_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset_name)\n\u001b[1;32m 976\u001b[0m split_generators_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_split_generators_kwargs(prepare_split_kwargs)\n\u001b[0;32m--> 977\u001b[0m split_generators \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_split_generators\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43msplit_generators_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# Checksums verification\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verification_mode \u001b[38;5;241m==\u001b[39m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS \u001b[38;5;129;01mand\u001b[39;00m dl_manager\u001b[38;5;241m.\u001b[39mrecord_checksums:\n",
|
81 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/datasets_modules/datasets/HuggingFaceM4--VQAv2/e4d008385143be7a6bd81e99483e671d5096942bcb987542217121a5ac2cb420/VQAv2.py:118\u001b[0m, in \u001b[0;36mVQAv2Dataset._split_generators\u001b[0;34m(self, dl_manager)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_split_generators\u001b[39m(\u001b[38;5;28mself\u001b[39m, dl_manager):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# urls = _URLS[self.config.name] # TODO later\u001b[39;00m\n\u001b[0;32m--> 118\u001b[0m data_dir \u001b[38;5;241m=\u001b[39m \u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_extract\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_URLS\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m gen_kwargs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 120\u001b[0m split_name: {\n\u001b[1;32m 121\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdir_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_path\u001b[39m\u001b[38;5;124m\"\u001b[39m: Path(data_dir[dir_name][split_name])\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m split_name \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest-dev\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 128\u001b[0m }\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 130\u001b[0m datasets\u001b[38;5;241m.\u001b[39mSplitGenerator(\n\u001b[1;32m 131\u001b[0m name\u001b[38;5;241m=\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mSplit\u001b[38;5;241m.\u001b[39mTRAIN,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 145\u001b[0m ),\n\u001b[1;32m 146\u001b[0m ]\n",
|
82 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:322\u001b[0m, in \u001b[0;36mDownloadManager.download_and_extract\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdownload_and_extract\u001b[39m(\u001b[38;5;28mself\u001b[39m, url_or_urls):\n\u001b[1;32m 307\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Download and extract given `url_or_urls`.\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m Is roughly equivalent to:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124;03m extracted_path(s): `str`, extracted paths of given URL(s).\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mextract(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m)\u001b[49m)\n",
|
83 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:159\u001b[0m, in \u001b[0;36mDownloadManager.download\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 157\u001b[0m start_time \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow()\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m stack_multiprocessing_download_progress_bars():\n\u001b[0;32m--> 159\u001b[0m downloaded_path_or_paths \u001b[38;5;241m=\u001b[39m \u001b[43mmap_nested\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_func\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_urls\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[43m \u001b[49m\u001b[43mmap_tuple\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDownloading data files\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m duration \u001b[38;5;241m=\u001b[39m datetime\u001b[38;5;241m.\u001b[39mnow() \u001b[38;5;241m-\u001b[39m start_time\n\u001b[1;32m 169\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading took \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mduration\u001b[38;5;241m.\u001b[39mtotal_seconds()\u001b[38;5;250m \u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m60\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m min\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
84 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:512\u001b[0m, in \u001b[0;36mmap_nested\u001b[0;34m(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, parallel_min_length, batched, batch_size, types, disable_tqdm, desc)\u001b[0m\n\u001b[1;32m 509\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m num_proc \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(iterable) \u001b[38;5;241m%\u001b[39m num_proc \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m), \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 510\u001b[0m iterable \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(iter_batched(iterable, batch_size))\n\u001b[1;32m 511\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\n\u001b[0;32m--> 512\u001b[0m \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m hf_tqdm(iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, desc\u001b[38;5;241m=\u001b[39mdesc)\n\u001b[1;32m 514\u001b[0m ]\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[1;32m 516\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [mapped_item \u001b[38;5;28;01mfor\u001b[39;00m mapped_batch \u001b[38;5;129;01min\u001b[39;00m mapped \u001b[38;5;28;01mfor\u001b[39;00m mapped_item \u001b[38;5;129;01min\u001b[39;00m mapped_batch]\n",
|
85 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:399\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[1;32m 396\u001b[0m k: _single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [\u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m mapped\n",
|
86 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:396\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(pbar_iterable, disable\u001b[38;5;241m=\u001b[39mdisable_tqdm, position\u001b[38;5;241m=\u001b[39mrank, unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobj\u001b[39m\u001b[38;5;124m\"\u001b[39m, desc\u001b[38;5;241m=\u001b[39mpbar_desc) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m):\n\u001b[1;32m 395\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m--> 396\u001b[0m k: \u001b[43m_single_map_nested\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m pbar\n\u001b[1;32m 397\u001b[0m }\n\u001b[1;32m 398\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 399\u001b[0m mapped \u001b[38;5;241m=\u001b[39m [_single_map_nested((function, v, batched, batch_size, types, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m pbar]\n",
|
87 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/py_utils.py:371\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_struct, types):\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batched:\n\u001b[0;32m--> 371\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdata_struct\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 373\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function(data_struct)\n",
|
88 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:216\u001b[0m, in \u001b[0;36mDownloadManager._download_batched\u001b[0;34m(self, url_or_filenames, download_config)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m thread_map(\n\u001b[1;32m 203\u001b[0m download_func,\n\u001b[1;32m 204\u001b[0m url_or_filenames,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 212\u001b[0m tqdm_class\u001b[38;5;241m=\u001b[39mtqdm,\n\u001b[1;32m 213\u001b[0m )\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_single\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m url_or_filename \u001b[38;5;129;01min\u001b[39;00m url_or_filenames\n\u001b[1;32m 218\u001b[0m ]\n",
|
89 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/download/download_manager.py:225\u001b[0m, in \u001b[0;36mDownloadManager._download_single\u001b[0;34m(self, url_or_filename, download_config)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_relative_path(url_or_filename):\n\u001b[1;32m 223\u001b[0m \u001b[38;5;66;03m# append the relative path to the base_path\u001b[39;00m\n\u001b[1;32m 224\u001b[0m url_or_filename \u001b[38;5;241m=\u001b[39m url_or_path_join(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_base_path, url_or_filename)\n\u001b[0;32m--> 225\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mcached_path\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 226\u001b[0m out \u001b[38;5;241m=\u001b[39m tracked_str(out)\n\u001b[1;32m 227\u001b[0m out\u001b[38;5;241m.\u001b[39mset_origin(url_or_filename)\n",
|
90 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:205\u001b[0m, in \u001b[0;36mcached_path\u001b[0;34m(url_or_filename, download_config, **download_kwargs)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;66;03m# Download external files\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 205\u001b[0m output_path \u001b[38;5;241m=\u001b[39m \u001b[43mget_from_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[43m \u001b[49m\u001b[43murl_or_filename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 208\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 209\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 210\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_etag\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_etag\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 211\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 212\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_desc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 214\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 215\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(url_or_filename):\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# File, and it exists.\u001b[39;00m\n\u001b[1;32m 218\u001b[0m output_path \u001b[38;5;241m=\u001b[39m url_or_filename\n",
|
91 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:415\u001b[0m, in \u001b[0;36mget_from_cache\u001b[0;34m(url, cache_dir, force_download, user_agent, use_etag, token, storage_options, download_desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 413\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not found in cache or force_download set to True, downloading to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtemp_file\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 414\u001b[0m \u001b[38;5;66;03m# GET file object\u001b[39;00m\n\u001b[0;32m--> 415\u001b[0m \u001b[43mfsspec_get\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_desc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisable_tqdm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_tqdm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstoring \u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in cache at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcache_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 418\u001b[0m shutil\u001b[38;5;241m.\u001b[39mmove(temp_file\u001b[38;5;241m.\u001b[39mname, cache_path)\n",
|
92 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/datasets/utils/file_utils.py:334\u001b[0m, in \u001b[0;36mfsspec_get\u001b[0;34m(url, temp_file, storage_options, desc, disable_tqdm)\u001b[0m\n\u001b[1;32m 321\u001b[0m fs, path \u001b[38;5;241m=\u001b[39m url_to_fs(url, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m(storage_options \u001b[38;5;129;01mor\u001b[39;00m {}))\n\u001b[1;32m 322\u001b[0m callback \u001b[38;5;241m=\u001b[39m TqdmCallback(\n\u001b[1;32m 323\u001b[0m tqdm_kwargs\u001b[38;5;241m=\u001b[39m{\n\u001b[1;32m 324\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdesc\u001b[39m\u001b[38;5;124m\"\u001b[39m: desc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 332\u001b[0m }\n\u001b[1;32m 333\u001b[0m )\n\u001b[0;32m--> 334\u001b[0m \u001b[43mfs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m\n",
|
93 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:118\u001b[0m, in \u001b[0;36msync_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m obj \u001b[38;5;129;01mor\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msync\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
94 |
+
"File \u001b[0;32m~/google_mlb/gemma/gemmarte/.venv/lib/python3.12/site-packages/fsspec/asyn.py:101\u001b[0m, in \u001b[0;36msync\u001b[0;34m(loop, func, timeout, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m return_result \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 99\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, asyncio\u001b[38;5;241m.\u001b[39mTimeoutError):\n\u001b[1;32m 100\u001b[0m \u001b[38;5;66;03m# suppress asyncio.TimeoutError, raise FSTimeoutError\u001b[39;00m\n\u001b[0;32m--> 101\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m FSTimeoutError \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mreturn_result\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(return_result, \u001b[38;5;167;01mBaseException\u001b[39;00m):\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m return_result\n",
|
95 |
+
"\u001b[0;31mFSTimeoutError\u001b[0m: "
|
96 |
+
]
|
97 |
+
}
|
98 |
+
],
|
99 |
+
"source": [
|
100 |
+
"from datasets import load_dataset \n",
|
101 |
+
"ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train\", trust_remote_code=True) \n",
|
102 |
+
"cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"] \n",
|
103 |
+
"ds = ds.remove_columns(cols_remove)\n",
|
104 |
+
"ds = ds.train_test_split(test_size=0.1)\n",
|
105 |
+
"train_ds = ds[\"train\"]\n",
|
106 |
+
"val_ds = ds[\"test\"]"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "markdown",
|
111 |
+
"metadata": {},
|
112 |
+
"source": [
|
113 |
+
"# Train (QLoRA 4-bit)"
|
114 |
+
]
|
115 |
+
},
|
116 |
+
{
|
117 |
+
"cell_type": "code",
|
118 |
+
"execution_count": null,
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [],
|
121 |
+
"source": [
|
122 |
+
"from transformers import PaliGemmaProcessor \n",
|
123 |
+
"model_id = \"google/paligemma-3b-pt-224\"\n",
|
124 |
+
"processor = PaliGemmaProcessor.from_pretrained(model_id)"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": null,
|
130 |
+
"metadata": {},
|
131 |
+
"outputs": [],
|
132 |
+
"source": [
|
133 |
+
"import torch\n",
|
134 |
+
"device = \"cuda\"\n",
|
135 |
+
"\n",
|
136 |
+
"image_token = processor.tokenizer.convert_tokens_to_ids(\"<image>\")\n",
|
137 |
+
"def collate_fn(examples):\n",
|
138 |
+
" texts = [\"answer \" + example[\"question\"] for example in examples]\n",
|
139 |
+
" labels= [example['multiple_choice_answer'] for example in examples] # μ°λ¦¬λ label μ΄ νμ μμλ―?\n",
|
140 |
+
" images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
|
141 |
+
" tokens = processor(text=texts, images=images, suffix=labels,\n",
|
142 |
+
" return_tensors=\"pt\", padding=\"longest\")\n",
|
143 |
+
"\n",
|
144 |
+
" tokens = tokens.to(torch.bfloat16).to(device)\n",
|
145 |
+
" return tokens"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": null,
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [],
|
153 |
+
"source": [
|
154 |
+
"from transformers import PaliGemmaForConditionalGeneration\n",
|
155 |
+
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
|
156 |
+
"\n",
|
157 |
+
"# Freeze the vision tower and the transformer encoder (image encoder)\n",
|
158 |
+
"for param in model.vision_tower.parameters():\n",
|
159 |
+
" param.requires_grad = False\n",
|
160 |
+
"\n",
|
161 |
+
"# Projector is not frozen (Article μμλ freeze νλ€κ³ λμ΄μμ)\n",
|
162 |
+
"for param in model.multi_modal_projector.parameters():\n",
|
163 |
+
" param.requires_grad = True\n"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "markdown",
|
168 |
+
"metadata": {},
|
169 |
+
"source": [
|
170 |
+
"For QLoRa in 4-bit"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "code",
|
175 |
+
"execution_count": null,
|
176 |
+
"metadata": {},
|
177 |
+
"outputs": [],
|
178 |
+
"source": [
|
179 |
+
"from transformers import BitsAndBytesConfig\n",
|
180 |
+
"from peft import get_peft_model, LoraConfig\n",
|
181 |
+
"\n",
|
182 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
183 |
+
" load_in_4bit=True,\n",
|
184 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
185 |
+
" bnb_4bit_compute_type=torch.bfloat16\n",
|
186 |
+
")\n",
|
187 |
+
"\n",
|
188 |
+
"lora_config = LoraConfig(\n",
|
189 |
+
" r=8, \n",
|
190 |
+
" target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
191 |
+
" task_type=\"CAUSAL_LM\",\n",
|
192 |
+
")\n",
|
193 |
+
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
|
194 |
+
"model = get_peft_model(model, lora_config)\n",
|
195 |
+
"model.print_trainable_parameters()\n",
|
196 |
+
"#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "code",
|
201 |
+
"execution_count": null,
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [],
|
204 |
+
"source": [
|
205 |
+
"from transformers import TrainingArguments\n",
|
206 |
+
"args=TrainingArguments(\n",
|
207 |
+
" num_train_epochs=2,\n",
|
208 |
+
" remove_unused_columns=False,\n",
|
209 |
+
" per_device_train_batch_size=16,\n",
|
210 |
+
" gradient_accumulation_steps=4,\n",
|
211 |
+
" warmup_steps=2,\n",
|
212 |
+
" learning_rate=2e-5,\n",
|
213 |
+
" weight_decay=1e-6,\n",
|
214 |
+
" adam_beta2=0.999,\n",
|
215 |
+
" logging_steps=100,\n",
|
216 |
+
" # optim=\"adamw_hf\",\n",
|
217 |
+
" optim=\"paged_adamw_8bit\", # for QLoRA\n",
|
218 |
+
" save_strategy=\"steps\",\n",
|
219 |
+
" save_steps=1000,\n",
|
220 |
+
" push_to_hub=True,\n",
|
221 |
+
" save_total_limit=1,\n",
|
222 |
+
" bf16=True,\n",
|
223 |
+
" report_to=[\"tensorboard\"],\n",
|
224 |
+
" dataloader_pin_memory=False\n",
|
225 |
+
" )\n"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"metadata": {},
|
232 |
+
"outputs": [],
|
233 |
+
"source": [
|
234 |
+
"from transformers import Trainer\n",
|
235 |
+
"trainer = Trainer(\n",
|
236 |
+
" model=model,\n",
|
237 |
+
" train_dataset=train_ds,\n",
|
238 |
+
" eval_dataset=val_ds,\n",
|
239 |
+
" data_collator=collate_fn,\n",
|
240 |
+
" args=args\n",
|
241 |
+
" )"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": null,
|
247 |
+
"metadata": {},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"trainer.train()"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "markdown",
|
255 |
+
"metadata": {},
|
256 |
+
"source": [
|
257 |
+
"# Inference for test"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": null,
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [],
|
265 |
+
"source": []
|
266 |
+
}
|
267 |
+
],
|
268 |
+
"metadata": {
|
269 |
+
"kernelspec": {
|
270 |
+
"display_name": "Python 3",
|
271 |
+
"language": "python",
|
272 |
+
"name": "python3"
|
273 |
+
},
|
274 |
+
"language_info": {
|
275 |
+
"codemirror_mode": {
|
276 |
+
"name": "ipython",
|
277 |
+
"version": 3
|
278 |
+
},
|
279 |
+
"file_extension": ".py",
|
280 |
+
"mimetype": "text/x-python",
|
281 |
+
"name": "python",
|
282 |
+
"nbconvert_exporter": "python",
|
283 |
+
"pygments_lexer": "ipython3",
|
284 |
+
"version": "3.12.2"
|
285 |
+
}
|
286 |
+
},
|
287 |
+
"nbformat": 4,
|
288 |
+
"nbformat_minor": 2
|
289 |
+
}
|