Multi-GPU fine-tuning

#30
by matbee - opened

Is it possible to use FSDP / DeepSpeed for fine-tuning this model?

HuggingFaceM4 org

Hi @matbee , yes, it is totally possible.
One low-barrier entry point is the HF Trainer: https://huggingface.co/docs/transformers/main/en/trainer
You can, for instance, adapt https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?authuser=1#scrollTo=nlEpIG4UBmoH, which provides some code to fine-tune on a single GPU. The main work would be adapting the accelerate config.
Let me know if you need reviews/help!
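
For context, the usual Accelerate workflow is just two commands (the file names below are placeholders):

# generate a config interactively; this is where you pick FSDP or DeepSpeed
accelerate config --config_file my_accelerate_config.yaml

# launch the (otherwise unchanged) training script with that config on all visible GPUs
accelerate launch --config_file my_accelerate_config.yaml train.py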

Here's what I've got so far. It IS training, though I'm unsure whether it's absolutely correct. I had to set mixed_precision to 'no' in the config.

I don't think it's properly sharding the model: each GPU uses the same amount of VRAM as it does when training on a single GPU.
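
As a quick sanity check on whether the model is actually being sharded (a rough sketch, separate from the script below), per-rank GPU memory can be logged like this:

# sketch only: log how much memory this rank's GPU actually holds;
# with FULL_SHARD each rank should hold only a fraction of the parameters
import torch
import torch.distributed as dist

def log_gpu_memory(tag=""):
    rank = dist.get_rank() if dist.is_initialized() else 0
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"[rank {rank}] {tag}: {allocated_gb:.2f} GB allocated")

The script: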

import torch
import random
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoProcessor, AutoTokenizer, BitsAndBytesConfig, Idefics2ForConditionalGeneration, TrainingArguments, Trainer
from datasets import load_dataset
from accelerate import PartialState

DEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = True

processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False
)

# Three options for training, from the lowest precision training to the highest precision training:
# - QLora
# - Standard Lora
# - Full fine-tuning
IDEFICS2_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

if USE_QLORA or USE_LORA:
    peft_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        init_lora_weights="gaussian"
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_storage=torch.bfloat16,
        )
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics2-8b", use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = IDEFICS2_CHAT_TEMPLATE

    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config if USE_QLORA else None,
        attn_implementation="flash_attention_2",
        device_map={"": PartialState().process_index}
    )
    # prepare_model_for_kbit_training should be applied to the quantized base model
    # before attaching the LoRA adapters (it is only needed for QLoRA)
    if USE_QLORA:
        model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    # model.add_adapter(peft_config)
    # model.enable_adapters()
else:
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2", # Only available on A100 or H100
    ).to(DEVICE)

### Load Dataset

train_dataset = load_dataset("nielsr/docvqa_1200_examples", split="train")
train_dataset = train_dataset.remove_columns(['id', 'words', 'bounding_boxes', 'answer'])

eval_dataset = load_dataset("nielsr/docvqa_1200_examples", split="test")
eval_dataset = eval_dataset.remove_columns(['id', 'words', 'bounding_boxes', 'answer'])

#### Dataset Formatting

class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = example["image"]
            question = example["query"]["en"]
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": question}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image])

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch

data_collator = MyDataCollator(processor)

### Training

training_args = TrainingArguments(
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=8,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    output_dir="./docvqa_ft_tutorial",
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    # evaluation_strategy="epoch",
    remove_unused_columns=False,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train()

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

The accelerate config (fsdp_config_qlora.yaml):

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

The launch command:

CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --config_file fsdp_config_qlora.yaml idefics2.py --train_type qlora --use_flash_attn true --use_peft_lora True --use_reentrant True --use_4bit_quantization True --bf16 true

Yeah, I think this might be beyond my current abilities.

HuggingFaceM4 org

I'll allocate some time to dig in. Looking at the config and the script, you should be pretty close, I think.

I believe one source of problems is that I need to parse the command-line arguments with HfArgumentParser and pass the resulting TrainingArguments into the Trainer.
The current state DOES seem to work with DDP, but not FSDP. It's likely some combination of arguments I'm using. It's very close though.
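
Roughly what I mean (a sketch, not my actual script; model, data_collator and train_dataset would be defined as above):

from transformers import HfArgumentParser, Trainer, TrainingArguments

# HfArgumentParser turns the command-line flags (e.g. --bf16, --output_dir, passed
# through accelerate launch) into a TrainingArguments dataclass; custom flags such
# as --train_type end up in `remaining` thanks to return_remaining_strings=True
parser = HfArgumentParser(TrainingArguments)
training_args, remaining = parser.parse_args_into_dataclasses(return_remaining_strings=True)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)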

HuggingFaceM4 org

If it is of any help in the meantime, here's the config I used to train on multiple GPUs with DeepSpeed (not FSDP).
I don't think it matters, but I passed all the parameters inside the TrainingArguments.

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: deepspeed_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: 1
num_processes: $NUM_GPUS
use_cpu: false

and the content of deepspeed_config.json:

{
    "communication_data_type": "fp32",
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "none"
        },
        "offload_optimizer": {
            "device": "none"
        }
    },
    "fp16": {
        "enabled": false
    },
    "bf16": {
        "enabled": true
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto"
}

the training script:

import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration
import safetensors

DEVICE = "cuda:0"
USE_4_BIT = False
RESUME_FROM_CHECKPOINT = False

processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=True,
)
if USE_4_BIT:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
else:
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    )#.to(DEVICE)

##
from peft import LoraConfig
from peft import get_peft_model

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    lora_dropout=0.1,
    target_modules='all-linear',
    use_dora=True,
    init_lora_weights="gaussian"
)

model = get_peft_model(model, lora_config)


##
from datasets import load_dataset, disable_caching
disable_caching()

train_dataset = load_dataset("HuggingFaceM4/DocumentVQA", split="train") # TO CHANGE with nielsr/docvqa_1200_examples_donut
train_dataset = train_dataset.remove_columns(['questionId', 'question_types', 'docId', 'ucsf_document_id', 'ucsf_document_page_no'])
eval_dataset = load_dataset("HuggingFaceM4/DocumentVQA", split="validation") # TO CHANGE with nielsr/docvqa_1200_examples_donut
eval_dataset = eval_dataset.remove_columns(['questionId', 'question_types', 'docId', 'ucsf_document_id', 'ucsf_document_page_no'])

##
import random

class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = example["image"]
            if image is None:
                continue
            question = example["question"]
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": question}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image])

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch

data_collator = MyDataCollator(processor)

##
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    warmup_steps=100,
    learning_rate=5e-5,
    weight_decay=0.1,
    logging_steps=10,
    output_dir="./docvqa_ft_tutorial",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    bf16=True,
    push_to_hub_model_id="test-victor",
    remove_unused_columns=False,
    report_to="none",
    deepspeed="deepspeed_config.json",
    save_safetensors=False,
    neftune_noise_alpha=5.0,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)

trainer.push_to_hub()

the launch command:

accelerate launch \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --config_file $ACCELERATE_CONFIG_FILE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --role \$(hostname -s): --tee 3 \
    docvqa_ft.py
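
(On a single node you can drop the SLURM/rendezvous bits; assuming the config above is saved as accelerate_config.yaml and the $MASTER_ADDR / $MASTER_PORT / $NUM_GPUS placeholders are replaced with concrete values, something like this should be equivalent:)

accelerate launch --config_file accelerate_config.yaml docvqa_ft.py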

Alright, I believe I've been able to take what you gave me and get a working ZeRO-3 DeepSpeed fine-tune. It would definitely need a little bit of lovin' to make it universal/shippable, but it seems to work!

https://gist.github.com/matbee-eth/466ec56c9fc82a15ac7ea0a1ba5df29c

HuggingFaceM4 org

Let's go! I'm glad this unblocked you!

By the way, how can I train it with AWS SageMaker? Do I need to make any changes to the code?


I've never used it, but from a quick look it should work. It doesn't use the accelerate + DeepSpeed config; it uses its own accelerate + SageMaker config. At least from my inexperienced view, it should be worth a shot.
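
A very rough, untested sketch of what that could look like with the SageMaker Python SDK's HuggingFace estimator (instance type, versions, role, and script name below are all placeholders/assumptions):

from sagemaker.huggingface import HuggingFace

# sketch only: run the training script above as a SageMaker training job
huggingface_estimator = HuggingFace(
    entry_point="idefics2_ft.py",            # hypothetical name for the script above
    source_dir=".",
    instance_type="ml.p4d.24xlarge",          # 8x A100 instance (assumption)
    instance_count=1,
    role="<your-sagemaker-execution-role>",
    transformers_version="4.36",              # pick versions matching an available DLC
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters={"epochs": 1},
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},  # SageMaker's own data parallelism
)
huggingface_estimator.fit()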

I'm trying to fine-tune on a multi-GPU node with 2 NVIDIA A100 GPUs, using the training script provided by Victor.

I'm getting a runtime error indicating a mismatch in device allocation between cuda:1 and cuda:0:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

 [WARNING]  using untested triton version (2.3.0), only 1.0.0 is known to be compatible
[2024-04-29 01:59:03,406] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-04-29 01:59:03,406] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-04-29 01:59:03,485] [INFO] [comm.py:637:init_distributed] cdb=None
Parameter Offload: Total persistent parameters: 30988016 in 1060 params
  0%|                                                                                                                                                  | 0/40000 [00:00<?, ?it/s]
No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.


No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/idefics2/deepspeed/idefics2_deepspeed_script.py", line 180, in <module>
[rank1]:     trainer.train()
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
[rank1]:     loss = self.compute_loss(model, inputs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/trainer.py", line 3161, in compute_loss
[rank1]:     outputs = model(**inputs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[rank1]:     ret_val = func(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1855, in forward
[rank1]:     loss = self.module(*inputs, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/peft/peft_model.py", line 563, in forward
[rank1]:     return self.get_base_model()(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1823, in forward
[rank1]:     outputs = self.model(
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1602, in forward
[rank1]:     inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1582, in _call_impl
[rank1]:     result = forward_call(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
[rank1]:     output = module._old_forward(*args, **kwargs)
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
[rank1]:     return F.embedding(
[rank1]:   File "/home/idefics2/idefics2_venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2264, in embedding
[rank1]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank1]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
W0429 01:59:14.322000 140198700533568 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 874502 closing signal SIGTERM
E0429 01:59:14.536000 140198700533568 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 1 (pid: 874503) of binary: /home/idefics2/idefics2_venv/bin/python3

default_config.yaml

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_config_file: /home/idefics2/deepspeed/deepspeed_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

deepspeed_config.json

{
    "communication_data_type": "fp32",
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "none"
        },
        "offload_optimizer": {
            "device": "none"
        }
    },
    "fp16": {
        "enabled": "auto"
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto"
}

HuggingFaceM4 org

Hi @as8311 ,
I usually encountered this kind of error when I was handling model or input device placement myself instead of handing it off to the Trainer (and Accelerate in the backend).
Any chance you are doing something similar?
By the way, in the DeepSpeed config I would recommend not setting fp16 AND bf16 to "auto", but rather to true/false depending on which mixed precision you are using. I don't think it's related to this issue, but it is less error-prone.
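
i.e., since the TrainingArguments use bf16=True, the precision section of deepspeed_config.json would look like this (sketch):

    "fp16": {
        "enabled": false
    },
    "bf16": {
        "enabled": true
    }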

Hey @VictorSanh ,
Thanks for getting back to me. I'll update the DeepSpeed config accordingly. I went through my code and didn't find any explicit handling of model or input device placement. Here's the code for reference:

import safetensors
import torch
import random
import pandas as pd
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration
from datasets import Dataset
from PIL import Image
from transformers.image_utils import load_image
from peft import get_peft_model

RESUME_FROM_CHECKPOINT = False

processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False,
)

USE_LORA = True
USE_QLORA = True

if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules = 'all-linear',
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian"
    )
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        quantization_config=bnb_config if USE_QLORA else None,
        torch_dtype=torch.bfloat16, 
    )
else:
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
    )

model = get_peft_model(model, lora_config)


##
from datasets import load_dataset, disable_caching
disable_caching()

train = pd.read_csv("/home/idefics2/data/caption/train.csv", encoding_errors='ignore',lineterminator='\n',
                    on_bad_lines='skip')
train_dataset = Dataset.from_pandas(train)

##
import random

class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
            processor.tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = load_image(example["main_image_url"])
            prompt = example["prompt"]
            answer = example["attribute_value_pair"]
            messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image"},
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": answer}
                        ]
                    }
            ]
            text = processor.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text.strip())
            images.append([image if image else None])

        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch

data_collator = MyDataCollator(processor)

##
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.1,
    logging_steps=25,
    output_dir="caption/",
    save_strategy="steps",
    save_steps=4000,
    save_total_limit=8,
    bf16=True,
    remove_unused_columns=False,
    report_to="tensorboard",
    deepspeed="deepspeed_config.json",
    save_safetensors=False,
    neftune_noise_alpha=5.0,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)
