"""Train a linear mapping between a CLIP image encoder and GPT-2 on COCO captions."""
import os

import torch
from datasets import load_dataset
from PIL import Image
from torchvision.io import ImageReadMode, read_image
from transformers import Trainer, TrainingArguments

from linear_mapping import LinearMapping, LinearMappingConfig, LinearMappingProcessor, Transform

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# COCO data location and the column names exposed by the loading script.
DATA_DIR = os.path.join(os.getcwd(), "coco")
CAPTION_COLUMN = "caption"
IMAGE_COLUMN = "image_path"


def main():
    # Load the COCO 2017 captions dataset via the community loading script.
    ds = load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=DATA_DIR)
    config = LinearMappingConfig()
    processor = LinearMappingProcessor(config)
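
    # Collate per-example features (tokens from `tokenize_fn`, pixels from the
    # lazy image transform) into batched tensors for the Trainer.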
    def collate_fn(batch):
        return {
            "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
            "input_ids": torch.tensor([x["input_ids"] for x in batch], dtype=torch.long),
            "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        }

    def tokenize_fn(examples):
        """Tokenize captions, optionally prepending the image token (the tokenizer's CLS token)."""
        texts = list(examples[CAPTION_COLUMN])
        if config.add_image_token:
            texts = [processor.tokenizer.cls_token + text for text in texts]
        inputs = processor.tokenizer(
            texts, padding="max_length", max_length=77,
            return_tensors="pt", truncation=True
        )
        examples["input_ids"] = inputs.input_ids
        examples["attention_mask"] = inputs.attention_mask
        return examples
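
    # Image preprocessing: resize to the model's input size and normalize with
    # the CLIP channel means and standard deviations, compiled with TorchScript.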
    image_transformations = Transform(
        config.image_resize,
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711]
    )
    image_transformations = torch.jit.script(image_transformations)

    def transform_images(examples):
        """Decode and preprocess images; widen the attention mask to cover the image prefix."""
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        examples["pixel_values"] = [image_transformations(image) for image in images]
        # Prepend ones for the `config.prefix_length` positions occupied by the
        # mapped image features.
        examples["attention_mask"] = torch.cat([
            torch.ones(len(images), config.prefix_length),
            torch.tensor(examples["attention_mask"])
        ], dim=1).to(dtype=torch.long)
        return examples
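
    # Note: `preprocess_fn` is not used below; the script tokenizes with
    # `tokenize_fn` and transforms images with `transform_images` instead.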
    def preprocess_fn(examples):
        texts = list(examples[CAPTION_COLUMN])
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        inputs = processor(
            texts=texts, images=images, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        )
        return inputs

    def filter_corrupt_images(examples):
        """Flag image files that PIL cannot open and verify."""
        valid_images = []
        for image_file in examples[IMAGE_COLUMN]:
            try:
                with Image.open(image_file) as image:
                    image.verify()  # catches truncated/corrupt files, not just bad headers
                valid_images.append(True)
            except Exception:
                valid_images.append(False)
        return valid_images

    train_dataset = ds["train"]

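    # Drop unreadable image files, tokenize all captions once up front, then
    # attach the image transform so pixels are decoded lazily per batch.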
    train_dataset = train_dataset.filter(
        function=filter_corrupt_images,
        batched=True
    )
    train_dataset = train_dataset.map(
        function=tokenize_fn,
        batched=True,
        remove_columns=[col for col in train_dataset.column_names if col not in (IMAGE_COLUMN, CAPTION_COLUMN)],
        load_from_cache_file=True
    )
    train_dataset.set_transform(transform_images)

    training_args = TrainingArguments(
        output_dir="clip-gpt2-image-captioner",
        do_train=True,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        learning_rate=5e-4,
        lr_scheduler_type="cosine",
        warmup_steps=500,
        max_grad_norm=1.0,
        logging_dir="runs",
        logging_steps=50,
        save_total_limit=3,
        # Keep all dataset columns: the Trainer would otherwise strip the ones
        # that `set_transform` and `collate_fn` rely on.
        remove_unused_columns=False,
    )
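    # Wire the model and dataset into the Trainer and launch training.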
    model = LinearMapping(config)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collate_fn
    )
    trainer.train()


if __name__ == "__main__":
    main()