import os
import time
import json
import torch
import numpy as np
from tqdm import tqdm
from loguru import logger
from sklearn.metrics import average_precision_score
from utils import seed_torch
from Detectors import ArtifactDetector, SemanticDetector
from Datasets import TrainDataset, TestDataset, EVAL_DATASET_LIST, EVAL_MODEL_LIST
import warnings
warnings.filterwarnings("ignore")
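
# Training / evaluation driver for binary deepfake detection. A Detector wraps
# either an ArtifactDetector or a SemanticDetector backbone behind a
# fully-connected head (self.model.fc) trained with BCEWithLogitsLoss
# (real vs. fake).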


class Detector():
    def __init__(self, args):
        super(Detector, self).__init__()
        # Device
        self.device = args.device
        # Get the detector
        if args.detector == "artifact":
            self.model = ArtifactDetector()
        elif args.detector == "semantic":
            self.model = SemanticDetector()
        else:
            raise ValueError("Unknown detector")
        # Put the model on the device
        self.model.to(self.device)
        # Initialize the fc layer
        torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
        # Optimizer
        _lr = 1e-4
        _beta1 = 0.9
        _weight_decay = 0.0
        params = [p for p in self.model.parameters() if p.requires_grad]
| print(f"Trainable parameters: {len(params)}") | |
        self.optimizer = torch.optim.AdamW(params, lr=_lr, betas=(_beta1, 0.999), weight_decay=_weight_decay)
        # Loss function
        self.criterion = torch.nn.BCEWithLogitsLoss()
        # Scheduler
        self.delr_freq = 10
        # Resume info
        self.start_epoch = 0
        self.best_acc = 0.0

    def train_step(self, batch_data):
        inputs, labels = batch_data
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels.unsqueeze(1).float())
        loss.backward()
        self.optimizer.step()
        eval_loss = loss.item()
        y_pred = outputs.sigmoid().flatten().tolist()
        y_true = labels.tolist()
        return eval_loss, y_pred, y_true

    def scheduler(self, status_dict):
        epoch = status_dict["epoch"]
        if epoch % self.delr_freq == 0 and epoch != 0:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] *= 0.9
                self.lr = param_group["lr"]
        return True
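
    # Note: this hand-rolled decay behaves like
    # torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.delr_freq, gamma=0.9)
    # stepped once per epoch, and could be swapped for that stock scheduler.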

    @torch.no_grad()  # inference only: no gradients needed for prediction
    def predict(self, inputs):
        inputs = inputs.to(self.device)
        outputs = self.model(inputs)
        return outputs.sigmoid().flatten().tolist()

    # --- Checkpoint functions ---
    def save_checkpoint(self, path, epoch, best_acc):
        torch.save({
            "epoch": epoch,
            "best_acc": best_acc,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict()
        }, path)

    def load_checkpoint(self, path):
        if os.path.exists(path):
            ckpt = torch.load(path, map_location=self.device)
            self.model.load_state_dict(ckpt["model_state"])
            self.optimizer.load_state_dict(ckpt["optimizer_state"])
            self.start_epoch = ckpt.get("epoch", 0) + 1
            self.best_acc = ckpt.get("best_acc", 0.0)
            print(f"[INFO] Loaded checkpoint '{path}' (start_epoch={self.start_epoch}, best_acc={self.best_acc})")
        else:
            print(f"[WARNING] Checkpoint not found: {path}")


def evaluate(y_pred, y_true):
    ap = average_precision_score(y_true, y_pred)
    accuracy = ((np.array(y_pred) > 0.5) == np.array(y_true)).mean()
    return ap, accuracy
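
# Sanity check: evaluate([0.9, 0.7, 0.2], [1, 0, 1]) returns AP = 0.8333...
# and accuracy = 0.3333... (only the first prediction matches at threshold 0.5).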


def train(args):
    # Get the detector
    detector = Detector(args)
    # --- Resume checkpoint ---
    start_epoch = 0
    best_acc = 0
    if args.resume != "":
        if os.path.exists(args.resume):
            print(f"[INFO] Loading checkpoint from {args.resume}")
            ckpt = torch.load(args.resume, map_location=args.device)
            detector.model.load_weights(args.resume)
            # If the optimizer & best_acc were also saved, restore them here
            if "best_acc" in ckpt:
                best_acc = ckpt["best_acc"]
            if "epoch" in ckpt:
                start_epoch = ckpt["epoch"] + 1
        else:
            print(f"[WARNING] Resume checkpoint not found: {args.resume}")
    # Load datasets
    train_dataset = TrainDataset(data_path=args.trainset_dirpath,
                                 split="train",
                                 transform=detector.model.train_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
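    # Note: the "test" loader below is the validation split of the training
    # data; it drives model selection, while test() runs the final benchmarks.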
    test_dataset = TrainDataset(data_path=args.trainset_dirpath,
                                split="val",
                                transform=detector.model.test_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)
    logger.info(f"Train size {len(train_dataset)} | Test size {len(test_dataset)}")
    # Set saving directory
    model_dir = os.path.join(args.ckpt, args.detector)
    os.makedirs(model_dir, exist_ok=True)
    log_path = f"{model_dir}/training.log"
    if os.path.exists(log_path):
        os.remove(log_path)
    logger_id = logger.add(log_path, format="{time:MM-DD at HH:mm:ss} | {level} | {module}:{line} | {message}", level="DEBUG")
    # Train loop
    for epoch in range(start_epoch, args.epochs):
        detector.scheduler({"epoch": epoch})  # step-decay the learning rate every delr_freq epochs
        detector.model.train()
        time_start = time.time()
        for step_id, batch_data in enumerate(train_loader):
            eval_loss, y_pred, y_true = detector.train_step(batch_data)
            ap, accuracy = evaluate(y_pred, y_true)
            if (step_id + 1) % 100 == 0:
                time_end = time.time()
                logger.info(f"Epoch {epoch} | Batch {step_id + 1}/{len(train_loader)} | Loss {eval_loss:.4f} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}% | Time {time_end-time_start:.2f}s")
                time_start = time.time()
        # Evaluate
        detector.model.eval()
        y_pred, y_true = [], []
        for (images, labels) in test_loader:
            y_pred.extend(detector.predict(images))
            y_true.extend(labels.tolist())
        ap, accuracy = evaluate(y_pred, y_true)
        logger.info(f"Epoch {epoch} | Test AP {ap*100:.2f}% | Test Accuracy {accuracy*100:.2f}%")
        # Save best model
        if accuracy >= best_acc:
            best_acc = accuracy
            detector.model.save_weights(f"{model_dir}/best_model.pth")
            torch.save({"epoch": epoch, "best_acc": best_acc}, f"{model_dir}/best_model_meta.pth")
            logger.info(f"Best model saved with accuracy {best_acc*100:.2f}%")
        # Save periodic checkpoints
        if epoch % 5 == 0:
            detector.model.save_weights(f"{model_dir}/epoch_{epoch}.pth")
            logger.info(f"Model saved at epoch {epoch}")
    # Save final model
    detector.model.save_weights(f"{model_dir}/final_model.pth")
    logger.info("Final model saved")
    logger.remove(logger_id)


def test(args):
    # Initialize the detector
    detector = Detector(args)
    # Load weights: prefer an explicit --resume checkpoint, otherwise the best model
    if args.resume != "" and os.path.exists(args.resume):
        print(f"[INFO] Loading checkpoint from {args.resume}")
        detector.model.load_weights(args.resume)
    else:
        if args.resume != "":
            print(f"[WARNING] Resume checkpoint not found: {args.resume}, falling back to the best model")
        weights_path = os.path.join(args.ckpt, args.detector, "best_model.pth")
        detector.model.load_weights(weights_path)
    detector.model.to(args.device)
    detector.model.eval()
    # Set the pre-processing function
    test_transform = detector.model.test_transform
    # Set the saving directory
    save_dir = os.path.join(args.ckpt, args.detector)
    save_result_path = os.path.join(save_dir, "result.json")
    save_output_path = os.path.join(save_dir, "output.json")
    # Begin the evaluation
    result_all = {}
    output_all = {}
    for dataset_name in EVAL_DATASET_LIST:
        result_all[dataset_name] = {}
        output_all[dataset_name] = {}
        for model_name in EVAL_MODEL_LIST:
            test_dataset = TestDataset(dataset=dataset_name, model=model_name, root_path=args.testset_dirpath, transform=test_transform)
            test_loader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=args.batch_size,
                                                      shuffle=False,
                                                      num_workers=4,
                                                      pin_memory=True)
            # Evaluate the model
            y_pred, y_true = [], []
            for (images, labels) in tqdm(test_loader, desc=f"Evaluating {dataset_name} {model_name}"):
                y_pred.extend(detector.predict(images))
                y_true.extend(labels.tolist())
            ap, accuracy = evaluate(y_pred, y_true)
            print(f"Evaluate on {dataset_name} {model_name} | Size {len(y_true)} | AP {ap*100:.2f}% | Accuracy {accuracy*100:.2f}%")
            result_all[dataset_name][model_name] = {"size": len(y_true), "AP": ap, "Accuracy": accuracy}
            output_all[dataset_name][model_name] = {"y_pred": y_pred, "y_true": y_true}
    # Save the results
    with open(save_result_path, "w") as f:
        json.dump(result_all, f, indent=4)
    with open(save_output_path, "w") as f:
        json.dump(output_all, f, indent=4)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Deep Fake Detection")
    parser.add_argument("--gpu", type=int, default=0, help="GPU ID")
    parser.add_argument("--phase", type=str, default="test", choices=["train", "test"], help="Phase of the experiment")
    parser.add_argument("--detector", type=str, default="artifact", choices=["artifact", "semantic"], help="Detector to use")
    parser.add_argument("--trainset_dirpath", type=str, default="data/train", help="Trainset directory")
    parser.add_argument("--testset_dirpath", type=str, default="data/test", help="Testset directory")
    parser.add_argument("--ckpt", type=str, default="ckpt", help="Checkpoint directory")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seed", type=int, default=1024, help="Random seed")
    parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume training")
    args = parser.parse_args()
    # Set the random seed
    seed_torch(args.seed)
    # Set the GPU ID
    args.device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    # Begin the experiment
    if args.phase == "train":
        train(args)
    elif args.phase == "test":
        test(args)
    else:
        raise ValueError("Unknown phase")