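# Hugging Face Hub upload metadata (page residue, kept as comments so the file parses):
# author: ash12321 -- commit message: "Upload model.py with huggingface_hub"
# commit: bb9059d (verified)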
"""
ResidualConvAutoencoder - Deepfake Detection Model
Architecture: 5-stage encoder-decoder with residual blocks
"""
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """
    Two 3x3 Conv+BatchNorm layers wrapped by an identity shortcut.

    Both convolutions keep the channel count and (with padding=1) the
    spatial size, so the block input can be added directly to the branch
    output:

        y = ReLU(BN2(Conv2(ReLU(BN1(Conv1(x))))) + x)

    Args:
        channels (int): number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names (and their creation order) are part of the
        # checkpoint format: they become the state_dict keys.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        # Shortcut: add the untouched input before the final activation.
        return self.relu(branch + x)
class ResidualConvAutoencoder(nn.Module):
    """
    Residual Convolutional Autoencoder for image reconstruction and deepfake detection.

    Args:
        latent_dim (int): Dimension of latent space (default: 512)

    Input:
        x: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]

    Output:
        reconstructed: Tensor of shape (batch_size, 3, 128, 128), values in [-1, 1]
        latent: Tensor of shape (batch_size, latent_dim)

    Note:
        The linear bottleneck is sized for a 512x4x4 encoder feature map,
        which five stride-2 stages produce only from 128x128 inputs.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: 128x128 -> 4x4. Each stride-2 4x4 conv (padding=1) halves
        # the spatial size. Stages 1-4 end in a ResidualBlock; stage 5 does not.
        # The layer order fixes the Sequential indices used as state_dict keys.
        self.encoder = nn.Sequential(
            # Stage 1: 128 -> 64
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64),
            # Stage 2: 64 -> 32
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128),
            # Stage 3: 32 -> 16
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256),
            # Stage 4: 16 -> 8
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512),
            # Stage 5: 8 -> 4
            nn.Conv2d(512, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Bottleneck: flattened 512x4x4 feature map <-> latent vector.
        self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)

        # Decoder: 4x4 -> 128x128, mirroring the encoder with stride-2
        # transposed convs (each doubles the spatial size).
        self.decoder = nn.Sequential(
            # Stage 1: 4 -> 8
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512),
            # Stage 2: 8 -> 16
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256),
            # Stage 3: 16 -> 32
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128),
            # Stage 4: 32 -> 64
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64),
            # Stage 5: 64 -> 128
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()  # Output in [-1, 1]
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            reconstructed: Reconstructed image of shape (batch_size, 3, 128, 128)
            latent: Latent representation of shape (batch_size, latent_dim)
        """
        # Reuse encode/decode so the three entry points cannot drift apart.
        latent = self.encode(x)
        reconstructed = self.decode(latent)
        return reconstructed, latent

    def encode(self, x):
        """
        Extract the latent representation only.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            Latent tensor of shape (batch_size, latent_dim)
        """
        features = self.encoder(x)
        # flatten(1) instead of view(N, -1): the encoder output need not be
        # contiguous (e.g. a channels_last input keeps that memory format
        # through Conv2d/BatchNorm2d/ReLU), and view() raises on such strides.
        # For contiguous tensors both give the same (C, H, W) element order.
        features = features.flatten(1)
        return self.fc_encoder(features)

    def decode(self, latent):
        """
        Reconstruct images from a latent representation.

        Args:
            latent: Tensor of shape (batch_size, latent_dim)

        Returns:
            Reconstructed image tensor of shape (batch_size, 3, 128, 128),
            values in [-1, 1]
        """
        x = self.fc_decoder(latent)
        # The Linear output is freshly allocated and contiguous, so view() is safe.
        x = x.view(x.size(0), 512, 4, 4)
        return self.decoder(x)

    def reconstruction_error(self, x, reduction='mean'):
        """
        Calculate the mean squared reconstruction error (MSE) of ``x``.

        Useful for anomaly/deepfake detection: inputs the autoencoder
        reconstructs poorly get a larger error.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)
            reduction: 'mean' returns a single scalar averaged over every
                element of the batch; 'none' returns a tensor of shape
                (batch_size,) with one MSE per sample.

        Returns:
            Reconstruction error (MSE), reduced as requested.

        Raises:
            ValueError: if ``reduction`` is not 'mean' or 'none'. The check
                runs before the (expensive) forward pass.
        """
        if reduction not in ('mean', 'none'):
            raise ValueError(f"Unknown reduction: {reduction}")

        reconstructed, _ = self.forward(x)
        error = (reconstructed - x) ** 2
        if reduction == 'mean':
            return error.mean()
        # reshape rather than view: `error` may inherit a non-contiguous
        # memory format from `x`.
        return error.reshape(x.size(0), -1).mean(dim=1)
def load_model(checkpoint_path, device='cuda', *, latent_dim=512, strict=True):
    """
    Load pretrained model from checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file (e.g. a .ckpt file).
        device: 'cuda' or 'cpu' (anything accepted by torch.load's
            ``map_location`` and ``Module.to``).
        latent_dim: Latent size the checkpoint was trained with. Defaults to
            512, the value this function previously hard-coded.
        strict: Forwarded to ``load_state_dict``. When True, missing or
            unexpected keys raise a RuntimeError.

    Returns:
        model: Loaded ResidualConvAutoencoder in eval mode

    Raises:
        TypeError: if the checkpoint is neither a state dict, a dict wrapping
            one under 'model_state_dict' / 'state_dict', nor a pickled
            nn.Module.
    """
    model = ResidualConvAutoencoder(latent_dim=latent_dim)

    # SECURITY NOTE(review): torch.load unpickles the file, so only load
    # checkpoints from trusted sources. On PyTorch versions where
    # weights_only defaults to False, a malicious file can execute code.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, nn.Module):
        # A whole model was saved with torch.save(model); take its weights.
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        # Prefer the wrapped layouts (training-loop dicts / Lightning .ckpt);
        # otherwise treat the dict itself as the state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
    else:
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__!r} "
            f"in {checkpoint_path!r}; expected a state dict or nn.Module"
        )

    model.load_state_dict(state_dict, strict=strict)
    model = model.to(device)
    model.eval()
    return model