# Install dependencies (in a notebook):
# !pip install -U "transformers>=4.39.0"
# !pip install peft bitsandbytes
# !pip install -U "trl>=0.8.3"

import torch
import io
import PIL.Image as Image
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTTrainer
from peft import LoraConfig
from datasets import Dataset
from huggingface_hub import notebook_login


def load_llava_model(model_id):
    # Load the base LLaVA model in 4-bit (bitsandbytes) so it fits on a single
    # GPU; the non-quantized parts run in fp16.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        torch_dtype=torch.float16,
    )
    torch.cuda.empty_cache()
    return model
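
# A more explicit 4-bit setup is also possible (a sketch, not used above; nf4
# quantization with fp16 compute is a common QLoRA-style choice, not a
# requirement of this script):
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )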

class LLavaDataCollator:
    """Collates chat-style examples into padded LLaVA training batches."""

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            # Render the conversation into a single prompt string via the
            # tokenizer's chat template.
            messages = example["messages"]
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            # Each example stores its image as raw bytes; decode it with PIL.
            img = Image.open(io.BytesIO(example["images"][0]["bytes"]))
            images.append(img)

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        # Use input_ids as labels, masking padding with -100 so it is ignored
        # by the loss.
        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch
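
# Sanity-check sketch for the collator (hypothetical usage, not executed here;
# assumes `processor` and `train_dataset` as built in main() below):
# collator = LLavaDataCollator(processor)
# batch = collator([train_dataset[0]])
# print(batch["input_ids"].shape, batch["pixel_values"].shape, batch["labels"].shape)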


def load_datasets(train_path, test_path):
    # The splits were pre-processed and written out with Dataset.save_to_disk.
    train_dataset = Dataset.load_from_disk(train_path)
    eval_dataset = Dataset.load_from_disk(test_path)
    return train_dataset, eval_dataset
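
# Each example is assumed to look roughly like this (inferred from how the
# collator reads it; the exact "messages" format must match whatever the
# tokenizer's chat template expects):
# {
#     "messages": [
#         {"role": "user", "content": "<image>\nWhat does this artwork depict?"},
#         {"role": "assistant", "content": "..."},
#     ],
#     "images": [{"bytes": b"<raw image bytes>", "path": None}],
# }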


def configure_training_args(output_dir="llava-1.5-7b-hf-ft-museum"):
    training_args = TrainingArguments(
        output_dir=output_dir,
        report_to="tensorboard",
        learning_rate=1.4e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        logging_steps=5,
        num_train_epochs=5,
        push_to_hub=True,
        gradient_checkpointing=True,
        # Keep the raw "messages"/"images" columns so the custom collator can
        # read them.
        remove_unused_columns=False,
        fp16=True,
        bf16=False,
    )
    return training_args
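
# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
# * num_gpus = 2 * 1 * 1 = 2 here. Raising gradient_accumulation_steps increases
# it without extra memory, e.g. (hypothetical override):
# args = configure_training_args()
# args.gradient_accumulation_steps = 8  # effective batch size 16 on one GPU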


def configure_lora_config():
    # LoRA adapters on all linear layers: only these low-rank matrices are
    # trained while the quantized base weights stay frozen.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        target_modules="all-linear",
    )
    return lora_config
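
# To preview how many parameters LoRA actually trains (a sketch for inspection
# only; SFTTrainer applies peft_config itself during training):
# from peft import get_peft_model
# get_peft_model(model, configure_lora_config()).print_trainable_parameters()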


def configure_trainer(model, args, train_dataset, eval_dataset, lora_config, tokenizer, data_collator):
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=lora_config,
        dataset_text_field="text",  # dummy field; preprocessing is skipped below
        tokenizer=tokenizer,
        data_collator=data_collator,
        # The collator does all preprocessing, so skip SFTTrainer's own
        # dataset preparation.
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    return trainer


def save_model(model, output_dir="llava_model_museum"):
    # Save the trained adapters locally; the Hub copy is pushed separately by
    # trainer.push_to_hub() in main().
    model.save_pretrained(output_dir)
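
# Reloading the saved adapters later (a sketch; assumes the same base model):
# from peft import PeftModel
# base_model = load_llava_model("llava-hf/llava-1.5-7b-hf")
# tuned_model = PeftModel.from_pretrained(base_model, "llava_model_museum")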


def main():
    # Load the LLaVA model
    model_id = "llava-hf/llava-1.5-7b-hf"
    model = load_llava_model(model_id)

    # Load datasets
    train_path = '/kaggle/input/metmesuem-data/train.arrow'
    test_path = '/kaggle/input/metmesuem-data/test.arrow'
    train_dataset, eval_dataset = load_datasets(train_path, test_path)

    # Log in to the Hugging Face Hub (required since push_to_hub=True)
    notebook_login()

    # Configure training arguments
    args = configure_training_args()

    # Configure LoRA config
    lora_config = configure_lora_config()

    # Configure trainer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    # Share one tokenizer instance between the trainer and the processor so the
    # chat template and padding behave consistently.
    processor.tokenizer = tokenizer
    data_collator = LLavaDataCollator(processor)
    trainer = configure_trainer(model, args, train_dataset, eval_dataset, lora_config, tokenizer, data_collator)

    # Load TensorBoard for logging (in a notebook):
    # %load_ext tensorboard
    # %tensorboard --logdir llava-1.5-7b-hf-ft-museum

    # Train model
    trainer.train()

    # Push the LoRA adapters to the Hugging Face Hub
    trainer.push_to_hub()

    # Save the adapters locally as well
    save_model(model)


if __name__ == "__main__":
    main()
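
# Quick inference sketch after training (hedged: "example.jpg" and the prompt
# are placeholders, and LLaVA-1.5 expects the "USER: <image>\n... ASSISTANT:"
# format):
# prompt = "USER: <image>\nDescribe this artwork. ASSISTANT:"
# image = Image.open("example.jpg")
# inputs = processor(prompt, image, return_tensors="pt").to(model.device)
# output_ids = model.generate(**inputs, max_new_tokens=100)
# print(processor.decode(output_ids[0], skip_special_tokens=True))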