pizb committed on
Commit
ee63d12
•
2 Parent(s): ece63b7 1758c0c

Merge branch 'train-baseline'

.gitignore CHANGED
@@ -1 +1,4 @@
- .venv
+ .venv
+ dataset
+ output
+ big_vision_repo
Finetune_PaliGemma_for_image_description.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Readme.md ADDED
@@ -0,0 +1,19 @@
+ # Dataset Structure
+ 
+ /custom_vqa_project/
+ │
+ ├── /dataset/
+ │   ├── /images/
+ │   │   ├── train/
+ │   │   │   ├── image1.jpg
+ │   │   │   ├── image2.jpg
+ │   │   └── val/
+ │   │       ├── image3.jpg
+ │   │       └── image4.jpg
+ │   ├── train.json        # Metadata for the training set
+ │   └── val.json          # Metadata for the validation set
+ │
+ ├── /scripts/
+ │   └── train.py          # Your fine-tuning script
+ │
+ └── README.md
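
Note: the README above lists train.json / val.json, while article_base_train.py (added below) actually reads CSV or Parquet metadata. A minimal sketch of building that metadata with the columns the script expects ('question', 'image', 'answer'); the records and file names here are hypothetical:

import pandas as pd

# Hypothetical records; 'image' holds file names relative to dataset/images/train.
records = [
    {"question": "Describe the painting.", "image": "image1.jpg",
     "answer": "A swirling night sky over a quiet village."},
    {"question": "Describe the painting.", "image": "image2.jpg",
     "answer": "A vase of sunflowers in warm yellows."},
]
df = pd.DataFrame(records)
df.to_parquet("dataset/train.parquet", index=False)   # for --metadata_type parquet (the default)
df.to_csv("dataset/train_samples.csv", index=False)   # for --metadata_type csv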
article_base_train.py ADDED
@@ -0,0 +1,186 @@
+ import os, time, math
+ import pandas as pd
+ from datasets import Dataset
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
+ import torch
+ from PIL import Image
+ from peft import get_peft_model, LoraConfig
+ import argparse
+ 
+ 
+ # Function to load a custom dataset from CSV
+ def load_custom_dataset_from_csv(csv_file, image_folder):
+     # Load CSV data using pandas
+     data = pd.read_csv(csv_file)
+ 
+     # Prepare dataset format for Hugging Face
+     questions = data['question'].tolist()
+     images = [os.path.join(image_folder, img) for img in data['image'].tolist()]
+     answers = data['answer'].tolist()
+ 
+     # Create a Hugging Face dataset from the loaded CSV
+     return Dataset.from_dict({
+         'question': questions,
+         'image': images,
+         'answer': answers
+     })
+ 
+ 
+ # Function to load a custom dataset from Parquet
+ def load_custom_dataset_from_parquet(parquet_file, image_folder):
+     # Load Parquet data using pandas
+     data = pd.read_parquet(parquet_file)
+ 
+     # Prepare dataset format for Hugging Face
+     questions = data['question'].tolist()
+     images = [os.path.join(image_folder, img) for img in data['image'].tolist()]
+     answers = data['answer'].tolist()
+ 
+     # Create a Hugging Face dataset from the loaded Parquet
+     return Dataset.from_dict({
+         'question': questions,
+         'image': images,
+         'answer': answers
+     })
+ 
+ 
+ # Choose the appropriate loader based on the metadata_type argument
+ def load_dataset_by_type(metadata_type, dataset_dir, image_folder):
+     if metadata_type == "csv":
+         return load_custom_dataset_from_csv(
+             os.path.join(dataset_dir, 'train_samples.csv'),
+             image_folder
+         )
+     elif metadata_type == "parquet":
+         return load_custom_dataset_from_parquet(
+             os.path.join(dataset_dir, 'train.parquet'),
+             image_folder
+         )
+     else:
+         raise ValueError("Unsupported metadata type. Use 'csv' or 'parquet'.")
+ 
+ 
+ def load_model_and_args(use_qlora, model_id, device, output_dir):
+     if use_qlora:
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16
+         )
+         lora_config = LoraConfig(
+             r=8,
+             target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
+             task_type="CAUSAL_LM"
+         )
+ 
+         model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0})
+         model = get_peft_model(model, lora_config)
+         model.print_trainable_parameters()
+ 
+         # TODO: Customize training settings
+         args = TrainingArguments(
+             output_dir=os.path.join(output_dir, f"{math.floor(time.time())}"),
+             num_train_epochs=2,
+             remove_unused_columns=False,
+             per_device_train_batch_size=1,
+             gradient_accumulation_steps=4,
+             warmup_steps=2,
+             learning_rate=2e-5,
+             weight_decay=1e-6,
+             logging_steps=100,
+             optim="adamw_hf",
+             save_strategy="steps",
+             save_steps=1000,
+             save_total_limit=1,
+             bf16=True,
+             report_to=["tensorboard"],
+             dataloader_pin_memory=False
+         )
+ 
+         return model, args
+     else:
+         model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
+         for param in model.vision_tower.parameters():
+             param.requires_grad = False
+ 
+         for param in model.multi_modal_projector.parameters():
+             param.requires_grad = True
+ 
+         # TODO: Customize training settings
+         args = TrainingArguments(
+             output_dir=os.path.join(output_dir, f"{math.floor(time.time())}"),
+             num_train_epochs=2,
+             remove_unused_columns=False,
+             per_device_train_batch_size=4,
+             gradient_accumulation_steps=4,
+             warmup_steps=2,
+             learning_rate=2e-5,
+             weight_decay=1e-6,
+             logging_steps=100,
+             optim="paged_adamw_8bit",
+             save_strategy="steps",
+             save_steps=1000,
+             save_total_limit=1,
+             bf16=True,
+             report_to=["tensorboard"],
+             dataloader_pin_memory=False
+         )
+ 
+         return model, args
+ 
+ 
+ # Main training function
+ def main(args):
+     dataset_dir = args.dataset_dir
+     model_id = args.model_id
+     output_dir = args.output_dir
+     metadata_type = args.metadata_type
+ 
+     # Load the custom dataset
+     # dataset = load_custom_dataset_from_csv(
+     #     os.path.join(dataset_dir, 'train_samples.csv'),
+     #     os.path.join(dataset_dir, 'images/train'))  # TODO: change to the appropriate path
+     dataset = load_dataset_by_type(metadata_type, dataset_dir, os.path.join(dataset_dir, 'images/train'))
+     train_val_split = dataset.train_test_split(test_size=0.1)
+ 
+     train_ds = train_val_split['train']
+     val_ds = train_val_split['test']
+ 
+     processor = PaliGemmaProcessor.from_pretrained(model_id)
+     device = "cuda"
+ 
+     model, args = load_model_and_args(args.use_qlora, model_id, device, output_dir)
+ 
+     # Custom collate function
+     def collate_fn(examples):
+         texts = [example["question"] for example in examples]
+         labels = [example['answer'] for example in examples]
+         images = [Image.open(example['image']).convert("RGB") for example in examples]
+         tokens = processor(text=texts, images=images, suffix=labels, return_tensors="pt", padding="longest")
+         tokens = tokens.to(torch.bfloat16).to(device)
+         return tokens
+ 
+     trainer = Trainer(
+         model=model,
+         train_dataset=train_ds,
+         eval_dataset=val_ds,
+         data_collator=collate_fn,
+         args=args
+     )
+ 
+     trainer.train()
+ 
+ 
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train a model on a custom dataset")
+     parser.add_argument('--dataset_dir', type=str, default='./dataset', help='Path to the folder containing the dataset')
+     parser.add_argument('--model_id', type=str, default='google/paligemma-3b-pt-224', help='Model ID to use for training')
+     parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save the output')
+     parser.add_argument('--use_qlora', action='store_true', help='Use QLoRA for training')
+     parser.add_argument('--metadata_type', type=str, default='parquet', choices=['csv', 'parquet'], help='Metadata format (csv or parquet)')
+     return parser.parse_args()
+ 
+ 
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
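
Before launching a full run, it can be worth sanity-checking one batch built with the same processor/collate logic the script uses. A minimal sketch, assuming the default model ID and a hypothetical record under dataset/images/train:

from PIL import Image
from transformers import PaliGemmaProcessor

processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")

# Hypothetical record matching the metadata columns above.
example = {"question": "Describe the painting.",
           "image": "dataset/images/train/image1.jpg",
           "answer": "A swirling night sky over a quiet village."}

tokens = processor(text=[example["question"]],
                   images=[Image.open(example["image"]).convert("RGB")],
                   suffix=[example["answer"]],
                   return_tensors="pt", padding="longest")
print({k: tuple(v.shape) for k, v in tokens.items()})  # e.g. input_ids, attention_mask, labels, pixel_values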
article_base_train_test.py DELETED
@@ -1,80 +0,0 @@
- from huggingface_hub import notebook_login
- from datasets import load_dataset
- from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
- import torch
- from peft import get_peft_model, LoraConfig
- 
- 
- def main():
-     ds = load_dataset('HuggingFaceM4/VQAv2', split="train", trust_remote_code=True)
-     cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
-     ds = ds.remove_columns(cols_remove)
-     ds = ds.train_test_split(test_size=0.1)
-     train_ds = ds["train"]
-     val_ds = ds["test"]
- 
-     model_id = "google/paligemma-3b-pt-224"
-     processor = PaliGemmaProcessor.from_pretrained(model_id)
-     image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
-     device = "cuda"
- 
-     bnb_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_compute_type=torch.bfloat16
-     )
-     lora_config = LoraConfig(
-         r=8,
-         target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
-         task_type="CAUSAL_LM",
-     )
-     model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0})
-     model = get_peft_model(model, lora_config)
-     model.print_trainable_parameters()
-     # trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344
- 
-     args = TrainingArguments(
-         num_train_epochs=2,
-         remove_unused_columns=False,
-         per_device_train_batch_size=16,
-         gradient_accumulation_steps=4,
-         warmup_steps=2,
-         learning_rate=2e-5,
-         weight_decay=1e-6,
-         adam_beta2=0.999,
-         logging_steps=100,
-         # optim="adamw_hf",
-         optim="paged_adamw_8bit",  # for QLoRA
-         save_strategy="steps",
-         save_steps=1000,
-         push_to_hub=True,
-         save_total_limit=1,
-         bf16=True,
-         report_to=["tensorboard"],
-         dataloader_pin_memory=False
-     )
- 
-     def collate_fn(examples):
-         texts = ["answer " + example["question"] for example in examples]
-         labels = [example['multiple_choice_answer'] for example in examples]  # We probably don't need labels here?
-         images = [example["image"].convert("RGB") for example in examples]
-         tokens = processor(text=texts, images=images, suffix=labels,
-                            return_tensors="pt", padding="longest")
- 
-         tokens = tokens.to(torch.bfloat16).to(device)
-         return tokens
- 
-     trainer = Trainer(
-         model=model,
-         train_dataset=train_ds,
-         eval_dataset=val_ds,
-         data_collator=collate_fn,
-         args=args
-     )
- 
-     trainer.train()
- 
- 
- if __name__ == "__main__":
-     notebook_login()
-     main()
article_base_tutorial.ipynb CHANGED
@@ -254,7 +254,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# Inference for test"
+    "Save Model"
    ]
   },
   {
@@ -262,7 +262,20 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "save_path = \"./fine_tuned_model\"\n",
+    "model.save_pretrained(save_path)\n",
+    "processor.save_pretrained(save_path)\n",
+    "\n",
+    "print(f\"Model saved locally at {save_path}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Inference for test"
+   ]
   }
  ],
  "metadata": {
requirements.txt ADDED
@@ -0,0 +1,74 @@
+ accelerate==0.34.2
+ aiohappyeyeballs==2.4.2
+ aiohttp==3.10.6
+ aiosignal==1.3.1
+ appnope==0.1.4
+ asttokens==2.4.1
+ attrs==24.2.0
+ certifi==2024.8.30
+ charset-normalizer==3.3.2
+ comm==0.2.2
+ datasets==3.0.1
+ debugpy==1.8.6
+ decorator==5.1.1
+ dill==0.3.8
+ executing==2.1.0
+ filelock==3.16.1
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ huggingface-hub==0.25.1
+ idna==3.10
+ ipykernel==6.29.5
+ ipython==8.27.0
+ ipywidgets==8.1.5
+ jedi==0.19.1
+ Jinja2==3.1.4
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ jupyterlab_widgets==3.0.13
+ MarkupSafe==2.1.5
+ matplotlib-inline==0.1.7
+ mpmath==1.3.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ nest-asyncio==1.6.0
+ networkx==3.3
+ numpy==2.1.1
+ packaging==24.1
+ pandas==2.2.3
+ parso==0.8.4
+ peft==0.13.0
+ pexpect==4.9.0
+ pillow==10.4.0
+ pip==24.0
+ platformdirs==4.3.6
+ prompt_toolkit==3.0.48
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pyarrow==17.0.0
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ pytz==2024.2
+ PyYAML==6.0.2
+ pyzmq==26.2.0
+ regex==2024.9.11
+ requests==2.32.3
+ safetensors==0.4.5
+ setuptools==75.1.0
+ six==1.16.0
+ stack-data==0.6.3
+ sympy==1.13.3
+ tokenizers==0.20.0
+ torch==2.4.1
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ transformers==4.45.1
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ wcwidth==0.2.13
+ widgetsnbextension==4.0.13
+ xxhash==3.5.0
+ yarl==1.13.0
test_inference.py ADDED
@@ -0,0 +1,24 @@
+ from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+ from PIL import Image
+ 
+ 
+ def main():
+     model_id = "google/paligemma-3b-pt-224"
+     # model_path = "output/1727488022/checkpoint-112"
+     model_path = "output/1727490265/checkpoint-450"
+     model = PaliGemmaForConditionalGeneration.from_pretrained(model_path)
+     processor = AutoProcessor.from_pretrained(model_id)
+ 
+     # prompt = "Analyze image from a critic's point of view."
+     prompt = "Please construct a formal analysis paragraph that is coherent and focuses solely on visual characteristics."
+     image_file_path = "dataset/images/manual_test/starry_night.jpg"
+     raw_image = Image.open(image_file_path)
+     inputs = processor(text=prompt, images=raw_image, return_tensors="pt")
+     output = model.generate(**inputs, max_new_tokens=20)
+ 
+     # Starry Night
+     print("Response: ", processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
+ 
+ 
+ if __name__ == "__main__":
+     main()
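
test_inference.py loads the checkpoint directly with from_pretrained, which works when full model weights were saved. If a run used --use_qlora, the checkpoint would instead hold a LoRA adapter; loading it could look like the sketch below, which assumes (hypothetically) that the same checkpoint path contains an adapter:

from transformers import PaliGemmaForConditionalGeneration
from peft import PeftModel

# Load the base model, then attach the (assumed) LoRA adapter checkpoint on top.
base = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
model = PeftModel.from_pretrained(base, "output/1727490265/checkpoint-450")
model = model.merge_and_unload()  # optionally fold the adapter into the base weights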