| import argparse |
| import torch |
| import cv2 |
| import os |
| import glob |
| import numpy as np |
| import ssl |
| |
# SECURITY: this swaps the process-wide default HTTPS context for one that
# skips certificate verification. http.client / urllib.request use this
# factory whenever no explicit context is passed, so every such HTTPS request
# made by this process becomes open to man-in-the-middle tampering.
# Presumably added so weight downloads work behind an intercepting proxy —
# confirm whether it is still needed. It stays enabled by default for
# backward compatibility; set DEEPFAKE_VERIFY_SSL=1 to keep verification on.
if os.environ.get("DEEPFAKE_VERIFY_SSL") != "1":
    ssl._create_default_https_context = ssl._create_unverified_context
|
|
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
| from src.models import DeepfakeDetector |
| from src.config import Config |
|
|
# Optional dependency for reading .safetensors checkpoints. When the package
# is missing, the flag is False and ``load_file`` is left undefined.
try:
    from safetensors.torch import load_file
except ImportError:
    SAFETENSORS_AVAILABLE = False
else:
    SAFETENSORS_AVAILABLE = True
|
|
def get_transform():
    """Build the inference-time preprocessing pipeline.

    The image is resized to a ``Config.IMAGE_SIZE`` square and normalised
    per channel with the standard ImageNet mean/std. It is then converted by
    ``ToTensorV2`` from an HWC numpy array to a CHW torch tensor. Presumably
    this must match the training-time preprocessing — confirm when changing it.
    """
    side = Config.IMAGE_SIZE
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
|
def _resolve_checkpoint_paths(checkpoints_arg):
    """Expand the checkpoints argument into a list of non-empty paths.

    A directory yields its ``*.safetensors`` files, or its ``*.pth`` files if
    it contains no ``.safetensors`` files. Any other value is treated as a
    comma-separated list of paths.
    """
    # Accept str or os.PathLike (e.g. a pathlib.Path default from Config).
    checkpoints_arg = os.fspath(checkpoints_arg)
    if os.path.isdir(checkpoints_arg):
        # sorted(): glob returns files in filesystem order, which is arbitrary.
        paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.safetensors")))
        if not paths:
            paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.pth")))
    else:
        paths = checkpoints_arg.split(",")
    # Strip whitespace and drop blanks left by stray or trailing commas.
    return [p.strip() for p in paths if p.strip()]


def _read_state_dict(path, device):
    """Read the raw state dict stored at *path*.

    ``.safetensors`` files are read with safetensors; tensors land on the CPU,
    and ``load_state_dict`` later copies them onto the model's device. Any
    other file goes through ``torch.load`` with ``map_location=device``.

    Raises:
        RuntimeError: *path* is a ``.safetensors`` file but the safetensors
            package is not installed.
        Exception: whatever the underlying reader raises for a missing or
            corrupt file.
    """
    if path.endswith(".safetensors"):
        if not SAFETENSORS_AVAILABLE:
            raise RuntimeError(
                "the 'safetensors' package is required to read .safetensors checkpoints"
            )
        return load_file(path)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints you trust; prefer .safetensors.
    return torch.load(path, map_location=device)


def _remap_legacy_keys(state_dict):
    """Return a copy of *state_dict* with legacy RGB-branch key prefixes renamed.

    ``rgb_branch.features.*`` becomes ``rgb_branch.net.features.*`` and
    ``rgb_branch.avgpool.*`` becomes ``rgb_branch.net.avgpool.*``. Every other
    key is copied unchanged, and the original order is preserved.
    """
    from collections import OrderedDict

    renames = (
        ("rgb_branch.features.", "rgb_branch.net.features."),
        ("rgb_branch.avgpool.", "rgb_branch.net.avgpool."),
    )
    remapped = OrderedDict()
    for key, value in state_dict.items():
        for old, new in renames:
            if key.startswith(old):
                # Rewrite only the prefix; str.replace would also rewrite any
                # later occurrence of the substring inside the key.
                key = new + key[len(old):]
                break
        remapped[key] = value
    return remapped


def load_models(checkpoints_arg, device):
    """
    Load one or multiple models for ensemble inference.

    Args:
        checkpoints_arg: A single checkpoint path, a comma-separated list of
            paths, or a directory. A directory contributes all of its
            ``*.safetensors`` files, or its ``*.pth`` files if it has none.
        device: Target device. Each model is moved there and put in eval mode.

    Returns:
        A non-empty list of ``DeepfakeDetector`` models. Checkpoints that fail
        to load are reported and skipped. If none load, a single model built
        with ``pretrained=False`` is returned so the pipeline can still run;
        its predictions are meaningless.
    """
    paths = _resolve_checkpoint_paths(checkpoints_arg)
    models = []
    print(f"Loading {len(paths)} model(s) for ensemble inference...")

    for path in paths:
        print(f"Loading: {path}")
        name = os.path.basename(path)

        # Read the file in its own step. A read failure then skips this path
        # outright and can never fall through to the remapping branch with a
        # stale state dict left over from an earlier iteration.
        try:
            state_dict = _read_state_dict(path, device)
        except Exception as e:
            print(f"❌ Failed to load {path}: {e}")
            continue

        model = DeepfakeDetector(pretrained=False)
        model.to(device)
        model.eval()

        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"⚠️ Initial load failed. Attempting legacy key remapping for {name}...")
            try:
                result = model.load_state_dict(_remap_legacy_keys(state_dict), strict=False)
            except Exception as e2:
                print(f"❌ Failed to load {path}: {e}")
                print(f"❌ Remapping also failed: {e2}")
                continue
            # strict=False silently skips keys that still do not match. Report
            # them so a partially-initialised model does not go unnoticed.
            missing = getattr(result, "missing_keys", [])
            unexpected = getattr(result, "unexpected_keys", [])
            if missing or unexpected:
                print(
                    f"⚠️ {name}: {len(missing)} missing and "
                    f"{len(unexpected)} unexpected key(s) after remapping"
                )
            models.append(model)
            print(f"✅ Successfully loaded (with remapping): {name}")
        else:
            models.append(model)
            print(f"✅ Successfully loaded: {name}")

    if not models:
        # Deliberate best-effort fallback so the CLI flow can be exercised
        # without usable weights.
        print("Warning: No valid checkpoints loaded. Using random initialization for testing flow.")
        model = DeepfakeDetector(pretrained=False).to(device)
        model.eval()
        models.append(model)

    return models
|
|
def predict_ensemble(models, image_path, device, transform):
    """Score one image with every model and average the per-model probabilities.

    Args:
        models: Sequence of models (expected to be in eval mode). Each must
            return a single-element logit tensor for a batch of one image,
            because ``.item()`` is called on its sigmoid.
        image_path: Path of the image file, read with ``cv2.imread``.
        device: Device the input tensor is moved to; it should match the models.
        transform: Albumentations-style pipeline called as
            ``transform(image=rgb_array)``. Its ``"image"`` output must be a
            CHW tensor.

    Returns:
        ``(avg_prob, None)`` on success, where ``avg_prob`` is the mean of the
        per-model sigmoid outputs. ``(None, message)`` on any failure, so that
        one bad file or inference error does not abort a batch run.
    """
    if not models:
        # Nothing to average; this would otherwise divide by zero below.
        return None, "Error: No models loaded"

    try:
        image = cv2.imread(image_path)  # None when the file cannot be read/decoded
        if image is None:
            return None, "Error: Could not read image"
        # OpenCV decodes to BGR; the transform pipeline is fed RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image_tensor = transform(image=image)["image"].unsqueeze(0).to(device)

        with torch.no_grad():
            probs = [torch.sigmoid(model(image_tensor)).item() for model in models]
    except Exception as e:
        # Per-image boundary: report the failure through the (prob, error)
        # contract instead of raising out of the caller's loop.
        return None, str(e)

    return sum(probs) / len(probs), None
|
|
def main():
    """Command-line entry point: classify one image or a directory of images.

    Loads the checkpoint ensemble, scores each image, and prints one
    ``name | label | confidence`` row per file. The confidence is the ensemble
    probability of the predicted label. Files for which predict_ensemble
    reports an error are printed as ERROR rows and skipped.
    """
    parser = argparse.ArgumentParser(description="Deepfake Detection Inference (Ensemble Support)")
    parser.add_argument("--source", type=str, required=True, help="Path to image or directory")
    parser.add_argument("--checkpoints", type=str, default=Config.ACTIVE_MODEL_PATH, help="Path to checkpoint file or directory (Default: Mark-V)")
    parser.add_argument("--device", type=str, default=Config.DEVICE, help="Device to use (cuda/mps/cpu)")
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Ensemble probability above which an image is labelled FAKE (default: 0.5)",
    )
    args = parser.parse_args()
    if not 0.0 <= args.threshold <= 1.0:
        parser.error("--threshold must be between 0 and 1")

    device = torch.device(args.device)
    print(f"Using device: {device}")

    models = load_models(args.checkpoints, device)
    transform = get_transform()

    if os.path.isdir(args.source):
        # Non-recursive scan. sorted() makes the output order reproducible,
        # since glob returns entries in arbitrary filesystem order.
        image_exts = ('.png', '.jpg', '.jpeg', '.webp')
        files = sorted(
            f for f in glob.glob(os.path.join(args.source, "*.*"))
            if f.lower().endswith(image_exts)
        )
    else:
        files = [args.source]

    print(f"Processing {len(files)} images with {len(models)} model(s)...")
    print("-" * 65)
    print(f"{'Image Name':<40} | {'Prediction':<10} | {'Confidence':<10}")
    print("-" * 65)

    for file_path in files:
        prob, error = predict_ensemble(models, file_path, device, transform)
        if error:
            print(f"{os.path.basename(file_path):<40} | ERROR: {error}")
            continue

        is_fake = prob > args.threshold
        label = "FAKE" if is_fake else "REAL"
        # Report the probability of whichever label was chosen.
        confidence = prob if is_fake else 1 - prob

        print(f"{os.path.basename(file_path):<40} | {label:<10} | {confidence:.2%}")
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|
|