| | import os
|
| | import torch
|
| | from datasets import load_dataset, ClassLabel, Image
|
| | from transformers import (
|
| | ViTImageProcessor,
|
| | ViTForImageClassification,
|
| | TrainingArguments,
|
| | Trainer,
|
| | DefaultDataCollator,
|
| | )
|
| | import evaluate
|
| | from torchvision.transforms import (
|
| | CenterCrop,
|
| | Compose,
|
| | Normalize,
|
| | RandomRotation,
|
| | RandomResizedCrop,
|
| | RandomHorizontalFlip,
|
| | RandomAdjustSharpness,
|
| | Resize,
|
| | ToTensor,
|
| | )
|
| | import numpy as np
|
| |
|
| |
|
# Pretrained checkpoint to fine-tune (ImageNet-21k/1k ViT-Base, 224x224 input).
MODEL_NAME = "google/vit-base-patch16-224"
# Expected layout: ./dataset/train/<class>/..., and optionally ./dataset/test/<class>/...
DATASET_DIR = "./dataset"
# Where the fine-tuned model and processor are saved.
OUTPUT_DIR = "./model"
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
|
| |
|
def main():
    """Fine-tune a pretrained ViT on an image-classification dataset.

    Expects DATASET_DIR laid out for the `datasets` "imagefolder" builder:
    ./dataset/train/<class>/*.jpg and optionally ./dataset/test/<class>/*.jpg.
    The fine-tuned model and its processor are written to OUTPUT_DIR.
    """
    print("Loading dataset...")

    # Check the expected layout up front; the imagefolder builder's own
    # error message is less actionable than this one.
    data_files = {}
    if os.path.exists(os.path.join(DATASET_DIR, "train")):
        data_files["train"] = os.path.join(DATASET_DIR, "train")
    if os.path.exists(os.path.join(DATASET_DIR, "test")):
        data_files["test"] = os.path.join(DATASET_DIR, "test")

    if not data_files:
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return

    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)

    # Everything below (label discovery, training) requires a train split;
    # fail with a clear message instead of a KeyError if only test/ exists.
    if "train" not in dataset:
        print(f"Error: no 'train' split found in {DATASET_DIR}; cannot train.")
        return

    # Map class names <-> string ids; stored in the model config so the
    # saved checkpoint predicts human-readable labels.
    labels = dataset["train"].features["label"].names
    id2label = {str(i): c for i, c in enumerate(labels)}
    label2id = {c: str(i) for i, c in enumerate(labels)}
    print(f"Labels found: {labels}")

    # The processor carries the normalization stats and target resolution
    # the checkpoint was pretrained with; reuse them for our transforms.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    image_mean = processor.image_mean
    image_std = processor.image_std
    # Some processors report {"shortest_edge": n} instead of height/width.
    size = processor.size["height"] if "height" in processor.size else processor.size["shortest_edge"]

    normalize = Normalize(mean=image_mean, std=image_std)

    # Light augmentation for training; deterministic resize + center crop
    # for evaluation.
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])

    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # convert("RGB") guards against grayscale/RGBA images in the folder.
        examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
        return examples

    print("Applying transforms...")
    # set_transform applies lazily per batch, so raw images stay on disk.
    dataset["train"].set_transform(train_transforms)
    has_eval = "test" in dataset
    if has_eval:
        dataset["test"].set_transform(val_transforms)

    print(f"Loading model {MODEL_NAME}...")
    # ignore_mismatched_sizes: the pretrained head has 1000 ImageNet classes;
    # it is replaced by a freshly initialized head with len(labels) outputs.
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Logits -> predicted class index, then plain accuracy.
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # Without a test split, per-epoch evaluation and best-model tracking
    # must be disabled, otherwise Trainer raises on eval_dataset=None.
    # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
    # recent transformers releases; the old name is kept for compatibility
    # with the version this script was written against.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        remove_unused_columns=False,  # keep "image" so set_transform can read it
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_eval,
        metric_for_best_model="accuracy",
        push_to_hub=False,
    )

    collator = DefaultDataCollator()

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_eval else None,
        tokenizer=processor,  # saved alongside checkpoints
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Done!")
|
| |
|
# Script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
| |
|