# clip-gpt2 / main.py

from datasets import load_dataset
from linear_mapping import LinearMapping, LinearMappingProcessor, LinearMappingConfig, Transform
import torch
from torchvision.io import ImageReadMode, read_image
from transformers import Trainer, TrainingArguments
import os
from PIL import Image
os.environ["WANDB_DISABLED"] = "true"
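
# Local COCO data root and the dataset column names that the pipeline below relies on.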
DATA_DIR = os.path.join(os.getcwd(), "coco")
CAPTION_COLUMN = "caption"
IMAGE_COLUMN = "image_path"
def main():
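    # Load COCO 2017 through the community loading script, then build the model
    # config and its processor; the processor exposes the text tokenizer used
    # below as processor.tokenizer.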
    ds = load_dataset("ydshieh/coco_dataset_script", "2017", data_dir=DATA_DIR)
    config = LinearMappingConfig()
    processor = LinearMappingProcessor(config)
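
    # Collate already-preprocessed examples into batched tensors for the Trainer:
    # pixel_values (B, 3, H, W), input_ids (B, 77) and attention_mask
    # (B, prefix_length + 77); see tokenize_fn and transform_images below.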
    def collate_fn(batch):
        return {
            'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
            'input_ids': torch.tensor([x['input_ids'] for x in batch], dtype=torch.long),
            'attention_mask': torch.stack([x["attention_mask"] for x in batch]),
        }
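
    # Tokenize captions to a fixed length of 77 tokens (CLIP's text context length).
    # When the config adds a dedicated image token, the tokenizer's CLS token is
    # prepended to every caption, presumably as the image-token placeholder.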
    def tokenize_fn(examples):
        texts = list(examples[CAPTION_COLUMN])
        if config.add_image_token:
            texts = [processor.tokenizer.cls_token + text for text in texts]
        inputs = processor.tokenizer(
            texts, padding="max_length", max_length=77,
            return_tensors="pt", truncation=True
        )
        examples["input_ids"] = inputs.input_ids
        examples["attention_mask"] = inputs.attention_mask
        return examples
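
    # Image preprocessing: resize to config.image_resize and normalize with the
    # standard CLIP mean/std statistics; torch.jit.script compiles the transform
    # so the on-the-fly set_transform path below stays fast.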
    image_transformations = Transform(
        config.image_resize,
        [0.48145466, 0.4578275, 0.40821073],
        [0.26862954, 0.26130258, 0.27577711]
    )
    image_transformations = torch.jit.script(image_transformations)
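
    # `Transform` lives in linear_mapping.py and is not shown here; a rough,
    # assumed torchvision equivalent (for reference only, not the actual
    # implementation) would look something like:
    #
    #     torch.nn.Sequential(
    #         Resize([config.image_resize], interpolation=InterpolationMode.BICUBIC),
    #         CenterCrop(config.image_resize),
    #         ConvertImageDtype(torch.float),
    #         Normalize(image_mean, image_std),
    #     )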
    def transform_images(examples):
        # Decode and transform the images, then extend each attention mask with
        # prefix_length ones so attention also covers the projected image prefix.
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        examples["pixel_values"] = [image_transformations(image) for image in images]
        examples["attention_mask"] = torch.cat([
            torch.ones(len(images), config.prefix_length),
            torch.tensor(examples["attention_mask"])
        ], dim=1).to(dtype=torch.long)
        return examples
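
    # Single-pass alternative that runs tokenization and image preprocessing
    # through the processor in one call; it is defined but not used by the
    # pipeline below, which tokenizes eagerly and transforms images lazily.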
    def preprocess_fn(examples):
        texts = list(examples[CAPTION_COLUMN])
        images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[IMAGE_COLUMN]]
        inputs = processor(
            texts=texts, images=images, padding="max_length", truncation=True, max_length=77, return_tensors="pt"
        )
        return inputs
    def filter_corrupt_images(examples):
        """Keep only examples whose image file can actually be opened."""
        valid_images = []
        for image_file in examples[IMAGE_COLUMN]:
            try:
                Image.open(image_file)
                valid_images.append(True)
            except Exception:
                valid_images.append(False)
        return valid_images
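
    # Training data pipeline: drop unreadable images, pre-tokenize the captions
    # (keeping only the image path and caption columns), then decode and
    # transform the images on the fly at fetch time via set_transform.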
    train_dataset = ds["train"]
    train_dataset = train_dataset.filter(
        function=filter_corrupt_images,
        batched=True
    )
    train_dataset = train_dataset.map(
        function=tokenize_fn,
        batched=True,
        remove_columns=[col for col in train_dataset.column_names if col not in (IMAGE_COLUMN, CAPTION_COLUMN)],
        load_from_cache_file=True
    )
    train_dataset.set_transform(transform_images)
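
    # remove_unused_columns must stay False, otherwise the Trainer would drop the
    # columns that set_transform and collate_fn rely on before they ever run.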
    training_args = TrainingArguments(
        learning_rate=5e-4,
        lr_scheduler_type='cosine',
        output_dir='clip-gpt2-image-captioner',
        do_train=True,
        logging_steps=50,
        num_train_epochs=5,
        logging_dir='runs',
        remove_unused_columns=False,
        max_grad_norm=1.0,
        per_device_train_batch_size=16,
        save_total_limit=3,
        warmup_steps=500
    )
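
    # Train the linear-mapping captioner; checkpoints go to ./clip-gpt2-image-captioner
    # and at most save_total_limit=3 of them are kept on disk.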
    model = LinearMapping(config)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collate_fn
    )
    trainer.train()

if __name__ == '__main__':
    main()
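
# Usage: `python main.py` with the COCO 2017 data reachable under ./coco (the
# data_dir handed to the loading script above). Weights & Biases logging is
# disabled via the WANDB_DISABLED environment variable set at the top.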