|
import io |
|
import os |
|
import shutil |
|
import requests |
|
import time |
|
import numpy as np |
|
from PIL import Image, ImageOps |
|
from math import nan |
|
import math |
|
import pickle |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
|
|
from torch.utils.data import Dataset, ConcatDataset, DataLoader |
|
from torchvision.datasets import ImageFolder |
|
import torchvision.transforms as T |
|
import torchvision.transforms.functional as TF |
|
from torch.cuda.amp import autocast, GradScaler |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
import transformers |
|
from transformers.modeling_flax_utils import FlaxPreTrainedModel |
|
from vqgan_jax.modeling_flax_vqgan import VQModel |
|
|
|
import gradio as gr |
|
|
|
# Torch device for the latent-space networks and intermediate tensors:
# the current CUDA device when CUDA is available, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class Model_Z1(nn.Module):
    """Residual conv net mapping a 256-channel map to a 256-channel map.

    The body has three widen/project stages: 256 -> 2048 -> 256,
    256 -> 1024 -> 256 and 256 -> 512 -> 256. Each projection back to 256
    channels adds the ORIGINAL input ``x`` as a skip connection; it does not
    add the previous stage's output.

    Every 3x3 convolution (padding 1, stride 1) is followed by ELU, so height
    and width are preserved. Batch norm follows each activation (after the
    skip addition where there is one), except on the final output.

    Attribute names must stay as they are: ``gen_sources`` restores these
    modules from a checkpoint with ``load_state_dict``, which matches
    parameters by name.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding=1 keeps the spatial size unchanged.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1)

        # Creation order is kept identical so seeded initialisation matches.
        self.conv1 = conv3x3(256, 2048)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = conv3x3(2048, 256)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = conv3x3(256, 1024)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = conv3x3(1024, 256)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = conv3x3(256, 512)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = conv3x3(512, 256)
        self.elu = nn.ELU()

    def forward(self, x):
        skip = x

        # Stage 1: widen to 2048 channels, project back, add the input.
        h = self.batchnorm(self.elu(self.conv1(x)))
        h = self.batchnorm2(self.elu(self.conv2(h)) + skip)

        # Stage 2: widen to 1024 channels, project back, add the input.
        h = self.batchnorm3(self.elu(self.conv3(h)))
        h = self.batchnorm4(self.elu(self.conv4(h)) + skip)

        # Stage 3: widen to 512 channels; the final projection is not normalised.
        h = self.batchnorm5(self.elu(self.conv5(h)))
        return self.elu(self.conv6(h)) + skip
|
|
|
class Model_Z(nn.Module):
    """Deeper residual conv net mapping a 256-channel map to 256 channels.

    The first part has the same three widen/project stages as the 256 ->
    2048/1024/512 -> 256 design: each projection back to 256 channels adds
    the ORIGINAL input ``x``. A tapering tail follows, going
    256 -> 448 -> 384 -> 320 -> 256, and its last layer again adds the input.

    Every 3x3 convolution (padding 1, stride 1) is followed by ELU, so height
    and width are preserved. Batch norm follows each activation (after the
    skip addition where there is one), except on the final output.

    Attribute names must stay as they are: ``gen_sources`` restores two
    instances of this class from checkpoints with ``load_state_dict``, which
    matches parameters by name.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            # 3x3 kernel with padding=1 keeps the spatial size unchanged.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1)

        # Creation order is kept identical so seeded initialisation matches.
        self.conv1 = conv3x3(256, 2048)
        self.batchnorm = nn.BatchNorm2d(2048)
        self.conv2 = conv3x3(2048, 256)
        self.batchnorm2 = nn.BatchNorm2d(256)
        self.conv3 = conv3x3(256, 1024)
        self.batchnorm3 = nn.BatchNorm2d(1024)
        self.conv4 = conv3x3(1024, 256)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.conv5 = conv3x3(256, 512)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.conv6 = conv3x3(512, 256)
        self.batchnorm6 = nn.BatchNorm2d(256)
        # Tapering tail.
        self.conv7 = conv3x3(256, 448)
        self.batchnorm7 = nn.BatchNorm2d(448)
        self.conv8 = conv3x3(448, 384)
        self.batchnorm8 = nn.BatchNorm2d(384)
        self.conv9 = conv3x3(384, 320)
        self.batchnorm9 = nn.BatchNorm2d(320)
        self.conv10 = conv3x3(320, 256)
        self.elu = nn.ELU()

    def forward(self, x):
        skip = x

        # Three widen/project stages, each closing with a skip from the input.
        h = self.batchnorm(self.elu(self.conv1(x)))
        h = self.batchnorm2(self.elu(self.conv2(h)) + skip)
        h = self.batchnorm3(self.elu(self.conv3(h)))
        h = self.batchnorm4(self.elu(self.conv4(h)) + skip)
        h = self.batchnorm5(self.elu(self.conv5(h)))
        h = self.batchnorm6(self.elu(self.conv6(h)) + skip)

        # Tapering tail 256 -> 448 -> 384 -> 320, then back to 256 plus the input.
        h = self.batchnorm7(self.elu(self.conv7(h)))
        h = self.batchnorm8(self.elu(self.conv8(h)))
        h = self.batchnorm9(self.elu(self.conv9(h)))
        return self.elu(self.conv10(h)) + skip
|
|
|
|
|
def tensor_jax(x):
    """Copy a PyTorch tensor into a channels-last ``jax.numpy`` array.

    A 3-D input is treated as a single sample and gets a leading batch axis.
    The tensor is then detached from autograd and its axes are reordered
    (0, 1, 2, 3) -> (0, 2, 3, 1), i.e. NCHW -> NHWC for image batches. It is
    moved to host memory and copied into a JAX array.
    """
    batched = x.unsqueeze(0) if x.dim() == 3 else x
    channels_last = batched.detach().permute(0, 2, 3, 1)
    return jnp.array(channels_last.cpu().numpy())
|
|
|
def jax_to_tensor(x):
    """Copy a channels-last JAX/NumPy array into a PyTorch tensor on ``device``.

    Axes are reordered (0, 3, 1, 2), i.e. NHWC -> NCHW for image batches,
    which inverts the permutation applied by ``tensor_jax``. The data is
    copied, so the result does not alias the JAX buffer.

    The tensor does not track gradients. This module only runs inference.
    Requesting ``requires_grad`` made every result an autograd leaf, which
    matters when called outside ``torch.no_grad()``. It also raised
    ``RuntimeError`` for non-floating dtypes, e.g. VQGAN code indices.
    """
    x_tensor = torch.tensor(np.array(x)).permute(0, 3, 1, 2).to(device)
    return x_tensor
|
|
|
|
|
# Preprocessing for uploaded images: resize to exactly 256x256 (the aspect
# ratio is not preserved), then convert with ToTensor to a CHW float tensor.
# For 8-bit PIL images, ToTensor scales values into [0, 1].
transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
    ]
)
|
|
|
# Hugging Face checkpoint of the VQGAN used to move between pixels and latents.
_VQGAN_NAME = "dalle-mini/vqgan_imagenet_f16_16384"

# Filled by _load_models() on first use so the networks are read from disk
# once per process instead of on every Gradio request.
_MODEL_CACHE = {}


def _load_latent_model(model_cls, path):
    """Build ``model_cls`` on ``device``, load weights from ``path``, set eval mode."""
    model = model_cls().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


def _load_models():
    """Return ``(vqgan, model_z1, model_z2, model_zdf)``, loading them once."""
    if "models" not in _MODEL_CACHE:
        # Build the full tuple before caching so a failed load (e.g. a
        # missing .pth file) leaves no partially-populated cache behind.
        models = (
            VQModel.from_pretrained(_VQGAN_NAME),
            _load_latent_model(Model_Z1, "./model_z1.pth"),
            _load_latent_model(Model_Z, "./model_z2.pth"),
            _load_latent_model(Model_Z, "./model_zdf.pth"),
        )
        _MODEL_CACHE["models"] = models
    return _MODEL_CACHE["models"]


def _decode_to_pil(vqgan, latent):
    """Decode a PyTorch latent batch with ``vqgan`` and return the first image as PIL."""
    decoded = jax_to_tensor(vqgan.decode(tensor_jax(latent))).squeeze(0)
    # The decoder output is not bounded to [0, 1]. ToPILImage multiplies
    # floats by 255 and casts to uint8, so out-of-range values would wrap
    # around (undefined conversion) instead of saturating.
    return T.ToPILImage()(decoded.clamp(0.0, 1.0).cpu())


def gen_sources(img):
    """Split an uploaded image into two reconstructed "source" images.

    The image is converted to RGB, resized to 256x256 and encoded to VQGAN
    latents. The latents go through ``model_z1`` and ``model_z2``, and each
    output is decoded back to an image. The sum of the two outputs is fed to
    ``model_zdf``. That result is decoded and compared with the preprocessed
    input using pixel-space MSE.

    Args:
        img: PIL image from the Gradio input component; ``None`` when the
            user submits without uploading anything.

    Returns:
        Tuple ``(source_1, source_2, loss)``: two PIL images and the MSE
        between the input and ``model_zdf``'s decoded reconstruction,
        rounded to 3 decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    model_vaq, model_z1, model_z2, model_zdf = _load_models()
    criterion = nn.MSELoss()

    with torch.no_grad():
        df_img = transform(img.convert("RGB")).unsqueeze(0).to(device)

        # Encode to VQGAN latents; the second return value is not needed.
        z_df, _ = model_vaq.encode(tensor_jax(df_img))
        z_df_tensor = jax_to_tensor(z_df)

        outputs_z1 = model_z1(z_df_tensor)
        outputs_z2 = model_z2(z_df_tensor)
        rec_img1_pil = _decode_to_pil(model_vaq, outputs_z1)
        rec_img2_pil = _decode_to_pil(model_vaq, outputs_z2)

        # Recombine the two latents and measure how well the result
        # reconstructs the input in pixel space (unclamped decoder output).
        outputs_zdf = model_zdf(outputs_z1 + outputs_z2)
        rec_df_tensor = jax_to_tensor(model_vaq.decode(tensor_jax(outputs_zdf)))
        dfimgloss = criterion(rec_df_tensor, df_img)

    return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3))
|
|
|
|
|
# Gradio UI: one uploaded image in. Out come the two reconstructed source
# images and the reconstruction loss, in the order gen_sources returns them.
_input_image = gr.Image(type="pil", label="Input Image")
_output_components = [
    gr.Image(type="pil", label="Source Image 1"),
    gr.Image(type="pil", label="Source Image 2"),
    gr.Number(label="Reconstruction Loss"),
]

interface = gr.Interface(
    fn=gen_sources,
    inputs=_input_image,
    outputs=_output_components,
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image",
    description="Upload an image.",
)

# Starts the web server when this module is executed or imported.
interface.launch()
|
|