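# Fine-tune PaliGemma on VQAv2 with a minimal hand-rolled LoRA implementation:
# low-rank adapters are injected into the q_proj/v_proj linear layers, the base
# weights are frozen, and only the adapter matrices are trained and saved.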
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from processor import MultiModalProcessor
from load_model import load_hf_model
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass, field
from typing import List
@dataclass
class LoraConfig:
    r: int = 8
    lora_alpha: int = 16
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    lora_dropout: float = 0.05
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    inference_mode: bool = False

    def __post_init__(self):
        # LoRA scales the low-rank update by alpha / r.
        self.scaling = self.lora_alpha / self.r
class LoraLinear(torch.nn.Module):
    """A frozen linear layer plus a trainable low-rank (LoRA) update."""

    def __init__(self, in_features, out_features, config: LoraConfig):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear.weight.requires_grad = False  # only the adapters train
        # Standard LoRA init: A is random so gradients can flow from the start,
        # B is zero so the wrapped layer initially reproduces the base layer.
        self.lora_A = torch.nn.Parameter(torch.empty((config.r, in_features)))
        torch.nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        self.lora_B = torch.nn.Parameter(torch.zeros((out_features, config.r)))
        self.scaling = config.scaling
        self.dropout = torch.nn.Dropout(p=config.lora_dropout)

    def forward(self, x):
        result = self.linear(x)
        # Low-rank update: x -> x A^T B^T, scaled by alpha / r.
        lora_output = (self.dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result + lora_output
def apply_lora_to_model(model, config: LoraConfig):
    # Freeze the base model; only the LoRA matrices added below will train.
    for param in model.parameters():
        param.requires_grad = False
    # Snapshot the targets first, since we mutate the module tree while iterating.
    targets = [(name, module) for name, module in model.named_modules()
               if isinstance(module, torch.nn.Linear)
               and any(t in name for t in config.target_modules)]
    for name, module in targets:
        lora_module = LoraLinear(module.in_features, module.out_features, config)
        lora_module = lora_module.to(module.weight.device)
        lora_module.linear.weight.data = module.weight.data
        if module.bias is not None:
            lora_module.linear.bias = module.bias
        # setattr on a dotted name would not replace the submodule; set the
        # attribute on the resolved parent module instead.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, lora_module)
    return model
# Load the dataset
ds = load_dataset('HuggingFaceM4/VQAv2', split="train[:10%]")
cols_remove = ["question_type", "answers", "answer_type", "image_id", "question_id"]
ds = ds.remove_columns(cols_remove)
# Create a small test split
split_ds = ds.train_test_split(test_size=0.05)
train_ds = split_ds["train"]
test_ds = split_ds["test"]
print(train_ds)
print(test_ds)
# Load the model and processor
model_id = "./paligemma-3b-pt-224"
model, tokenizer = load_hf_model(model_id, "cuda")
processor = MultiModalProcessor(tokenizer, model.config.vision_config.num_image_tokens, model.config.vision_config.image_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Apply LoRA to the model
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05)
model = apply_lora_to_model(model, lora_config)
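# Sanity check: with the base weights frozen, only the LoRA A/B matrices
# should show up as trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")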
# Define a custom dataset
class PaliGemmaDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        prompt = "answer " + item["question"]
        image = item["image"].convert("RGB")
        answer = item["multiple_choice_answer"]
        # Process inputs
        inputs = self.processor(text=[prompt], images=[image])
        # Tokenize the answer the same way to obtain the label ids
        label_inputs = self.processor(text=[answer], images=[image])
        labels = label_inputs['input_ids'][0]
        # Mask the input part with -100 so loss is only computed on the
        # label tokens at the end of the sequence.
        inputs['labels'] = torch.full_like(inputs['input_ids'][0], -100)
        inputs['labels'][-len(labels):] = labels.clone()
        return inputs
# Create datasets
train_dataset = PaliGemmaDataset(train_ds, processor)
eval_dataset = PaliGemmaDataset(test_ds, processor)
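# Optional smoke test: pull one example and print the tensor shapes the
# collator will see (exact keys depend on MultiModalProcessor's output).
example = train_dataset[0]
print({k: tuple(v.shape) for k, v in example.items() if hasattr(v, "shape")})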
# Define a custom data collator
def custom_data_collator(features):
    # Stacking assumes the processor pads every example to the same sequence
    # length; otherwise input_ids/attention_mask would not align across examples.
    batch = {
        'pixel_values': torch.stack([f['pixel_values'][0] for f in features]),
        'input_ids': torch.stack([f['input_ids'][0] for f in features]),
        'attention_mask': torch.stack([f['attention_mask'][0] for f in features]),
        'labels': torch.stack([f['labels'] for f in features])
    }
    return batch
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=custom_data_collator,
)
# Fine-tune the model
trainer.train()
# Save the fine-tuned model
trainer.save_model("lora_paligemma_vqa")
# Function to save LoRA weights separately
def save_lora_weights(model, path):
    # Collect only the adapter matrices; the frozen base weights are not saved.
    lora_state_dict = {}
    for name, module in model.named_modules():
        if isinstance(module, LoraLinear):
            lora_state_dict[f"{name}.lora_A"] = module.lora_A.data
            lora_state_dict[f"{name}.lora_B"] = module.lora_B.data
    torch.save(lora_state_dict, path)
# Save LoRA weights
save_lora_weights(model, "lora_weights.pt")
print("Fine-tuning completed and model saved.") |