# Fine-tune a ViT image classifier on a local "imagefolder"-style dataset.
# (Header reconstructed; original lines were extraction residue: "Spaces:" / "Build error".)
import os

import numpy as np
import torch
import evaluate
from datasets import load_dataset, ClassLabel, Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomRotation,
    RandomResizedCrop,
    RandomHorizontalFlip,
    RandomAdjustSharpness,
    Resize,
    ToTensor,
)
from transformers import (
    ViTImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)

# --- Configuration ---
MODEL_NAME = "google/vit-base-patch16-224"  # base ViT checkpoint to fine-tune
DATASET_DIR = "./dataset"                   # expects train/ (and optionally test/) subfolders
OUTPUT_DIR = "./model"                      # where the fine-tuned model + processor are saved
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
def main():
    """Fine-tune a ViT image classifier on a local imagefolder dataset.

    Expects the directory layout:
        ./dataset/train/<LABEL>/*   (required)
        ./dataset/test/<LABEL>/*    (optional; enables per-epoch evaluation)

    Saves the fine-tuned model and its processor to OUTPUT_DIR.
    """
    # 1. Validate the dataset layout up front so the user gets a clear
    # message instead of a cryptic load_dataset/KeyError failure.
    # (The original built a data_files dict that was never passed to
    # load_dataset; the existence checks are all that was actually used.)
    print("Loading dataset...")
    has_train = os.path.isdir(os.path.join(DATASET_DIR, "train"))
    if not has_train:
        print(f"Error: No data found in {DATASET_DIR}. Please organize data in 'train' and 'test' folders.")
        print("Expected structure: ./dataset/train/REAL, ./dataset/train/FAKE, etc.")
        return

    # The Hugging Face "imagefolder" builder discovers splits and class
    # labels automatically from the directory tree.
    dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)

    # 2. Label <-> id mappings. String ids match the transformers
    # config (de)serialization convention.
    labels = dataset["train"].features["label"].names
    id2label = {str(i): name for i, name in enumerate(labels)}
    label2id = {name: str(i) for i, name in enumerate(labels)}
    print(f"Labels found: {labels}")

    # 3. Preprocessing pipelines built from the processor's own
    # normalization statistics and target resolution.
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    size = processor.size["height"]  # ViT processors expose {"height": H, "width": W}
    normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

    # Training pipeline: light augmentation; validation pipeline: deterministic.
    _train_transforms = Compose([
        RandomResizedCrop(size),
        RandomHorizontalFlip(),
        RandomAdjustSharpness(2),
        ToTensor(),
        normalize,
    ])
    _val_transforms = Compose([
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ])

    def train_transforms(examples):
        # convert("RGB") guards against grayscale / RGBA source images.
        examples["pixel_values"] = [_train_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    def val_transforms(examples):
        examples["pixel_values"] = [_val_transforms(img.convert("RGB")) for img in examples["image"]]
        return examples

    # set_transform applies lazily per batch, so images are decoded on demand.
    print("Applying transforms...")
    dataset["train"].set_transform(train_transforms)
    has_eval = "test" in dataset
    if has_eval:
        dataset["test"].set_transform(val_transforms)

    # 4. Model. ignore_mismatched_sizes lets us replace the pretrained
    # 1000-class ImageNet head with one sized for our label set.
    print(f"Loading model {MODEL_NAME}...")
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )

    # 5. Metrics
    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Logits -> predicted class index, then delegate to the accuracy metric.
        predictions = np.argmax(eval_pred.predictions, axis=1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # 6. Training arguments. Fix: the original unconditionally enabled
    # per-epoch evaluation and load_best_model_at_end, which makes the
    # Trainer fail when no "test" split exists even though the rest of the
    # script treats that split as optional. Both now depend on has_eval.
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        remove_unused_columns=False,  # keep the "image" column for set_transform
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=has_eval,
        metric_for_best_model="accuracy",
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if has_eval else None,
        tokenizer=processor,  # saved alongside checkpoints for easy reload
        data_collator=DefaultDataCollator(),
        compute_metrics=compute_metrics,
    )

    # 7. Train
    print("Starting training...")
    trainer.train()

    # 8. Save the final model + processor for inference.
    print(f"Saving model to {OUTPUT_DIR}...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print("Done!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()