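"""Gradio app: uncover the two source images behind a suspected deepfake.

The input image is encoded with a pretrained VQGAN; three residual CNNs
then map its latent to the two estimated source latents (model_z1,
model_z2) and back to a reconstructed deepfake latent (model_zdf). The
decoded results and the image-space reconstruction loss are shown in the UI.
"""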
import warnings
warnings.filterwarnings("ignore")

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as T

import jax.numpy as jnp
from vqgan_jax.modeling_flax_vqgan import VQModel

import gradio as gr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model_Z1(nn.Module):
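    """Residual CNN that maps a deepfake's VQGAN latent (256 channels)
    to the latent of the first source image."""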
def __init__(self):
super(Model_Z1, self).__init__()
self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
self.batchnorm = nn.BatchNorm2d(2048)
self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
self.batchnorm2 = nn.BatchNorm2d(256)
self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
self.batchnorm3 = nn.BatchNorm2d(1024)
self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
self.batchnorm4 = nn.BatchNorm2d(256)
self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
self.batchnorm5 = nn.BatchNorm2d(512)
self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
self.elu = nn.ELU()
def forward(self, x):
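        # Three expand-and-project stages; each projection back to 256
        # channels adds the original input as a residual.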
res = x
x = self.elu(self.conv1(x))
x = self.batchnorm(x)
x = self.elu(self.conv2(x)) + res
x = self.batchnorm2(x)
x = self.elu(self.conv3(x))
x = self.batchnorm3(x)
x = self.elu(self.conv4(x)) + res
x = self.batchnorm4(x)
x = self.elu(self.conv5(x))
x = self.batchnorm5(x)
out = self.elu(self.conv6(x)) + res
return out
class Model_Z(nn.Module):
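    """Deeper residual CNN over 256-channel latents; instantiated both as
    model_z2 (second source latent) and model_zdf (deepfake reconstruction)."""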
def __init__(self):
super(Model_Z, self).__init__()
self.conv1 = nn.Conv2d(in_channels=256, out_channels=2048, kernel_size=3, padding=1)
self.batchnorm = nn.BatchNorm2d(2048)
self.conv2 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, padding=1)
self.batchnorm2 = nn.BatchNorm2d(256)
self.conv3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, padding=1)
self.batchnorm3 = nn.BatchNorm2d(1024)
self.conv4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, padding=1)
self.batchnorm4 = nn.BatchNorm2d(256)
self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
self.batchnorm5 = nn.BatchNorm2d(512)
self.conv6 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
self.batchnorm6 = nn.BatchNorm2d(256)
self.conv7 = nn.Conv2d(in_channels=256, out_channels=448, kernel_size=3, padding=1)
self.batchnorm7 = nn.BatchNorm2d(448)
self.conv8 = nn.Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
self.batchnorm8 = nn.BatchNorm2d(384)
self.conv9 = nn.Conv2d(in_channels=384, out_channels=320, kernel_size=3, padding=1)
self.batchnorm9 = nn.BatchNorm2d(320)
self.conv10 = nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, padding=1)
self.elu = nn.ELU()
def forward(self, x):
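        # Same three residual stages as Model_Z1, followed by a channel
        # taper (448 -> 384 -> 320 -> 256) and a final residual add.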
res = x
x = self.elu(self.conv1(x))
x = self.batchnorm(x)
x = self.elu(self.conv2(x)) + res
x = self.batchnorm2(x)
x = self.elu(self.conv3(x))
x = self.batchnorm3(x)
x = self.elu(self.conv4(x)) + res
x = self.batchnorm4(x)
x = self.elu(self.conv5(x))
x = self.batchnorm5(x)
x = self.elu(self.conv6(x)) + res
x = self.batchnorm6(x)
x = self.elu(self.conv7(x))
x = self.batchnorm7(x)
x = self.elu(self.conv8(x))
x = self.batchnorm8(x)
x = self.elu(self.conv9(x))
x = self.batchnorm9(x)
out = self.elu(self.conv10(x)) + res
return out
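##---------------- Torch <-> JAX conversion helpers ----------------
# The VQGAN runs in JAX (NHWC layout) while the latent models run in
# PyTorch (NCHW layout), so tensors are converted at each boundary.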
def tensor_jax(x):
    """Convert a PyTorch tensor (N, C, H, W) to a JAX array (N, H, W, C)."""
    if x.dim() == 3:
        x = x.unsqueeze(0)  # add a batch dimension
    x_np = x.detach().permute(0, 2, 3, 1).cpu().numpy()
    return jnp.array(x_np)
def jax_to_tensor(x):
    """Convert a JAX array (N, H, W, C) to a PyTorch tensor (N, C, H, W) on `device`."""
    return torch.tensor(np.array(x)).permute(0, 3, 1, 2).to(device)
# Preprocessing: resize to the VQGAN's 256x256 input and scale to [0, 1].
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])
# Load the VQGAN and the three latent models once at import time so each
# request only pays for inference.
model_name = "dalle-mini/vqgan_imagenet_f16_16384"
model_vaq = VQModel.from_pretrained(model_name)

model_z1 = Model_Z1().to(device)
model_z1.load_state_dict(torch.load("./model_z1.pth", map_location=device))
model_z1.eval()

model_z2 = Model_Z().to(device)
model_z2.load_state_dict(torch.load("./model_z2.pth", map_location=device))
model_z2.eval()

model_zdf = Model_Z().to(device)
model_zdf.load_state_dict(torch.load("./model_zdf.pth", map_location=device))
model_zdf.eval()

criterion = nn.MSELoss()

def gen_sources(img):
    """Return the two estimated source images and the deepfake
    reconstruction loss for a (suspected) deepfake PIL image."""
    with torch.no_grad():
        # Preprocess: RGB, resize, and batch -> (1, 3, 256, 256) on `device`.
        img = img.convert("RGB")
        df_img = transform(img).unsqueeze(0).to(device)

        # Encode the input to its quantized VQGAN latent z.
        df_img_jax = tensor_jax(df_img)
        z_df, _ = model_vaq.encode(df_img_jax)
        z_df_tensor = jax_to_tensor(z_df)
        # model_z1: estimate the first source latent and decode it to an image.
        outputs_z1 = model_z1(z_df_tensor)
        rec_img1 = model_vaq.decode(tensor_jax(outputs_z1))

        # model_z2: estimate the second source latent and decode it to an image.
        outputs_z2 = model_z2(z_df_tensor)
        rec_img2 = model_vaq.decode(tensor_jax(outputs_z2))
        # model_zdf: recombine the two estimated latents and reconstruct the deepfake.
        z_rec = outputs_z1 + outputs_z2
        outputs_zdf = model_zdf(z_rec)
        lossdf = criterion(outputs_zdf, z_df_tensor)  # latent-space loss (not surfaced in the UI)

        # Image-space reconstruction loss of the deepfake, reported to the user.
        rec_df = model_vaq.decode(tensor_jax(outputs_zdf))
        rec_df_tensor = jax_to_tensor(rec_df)
        dfimgloss = criterion(rec_df_tensor, df_img)
        # Convert the decoded images back to PIL, clamping to the valid
        # [0, 1] range first (decoded values can fall slightly outside it).
        rec_img1 = jax_to_tensor(rec_img1).squeeze(0).clamp(0, 1)
        rec_img2 = jax_to_tensor(rec_img2).squeeze(0).clamp(0, 1)
        rec_df = rec_df_tensor.squeeze(0).clamp(0, 1)
        rec_img1_pil = T.ToPILImage()(rec_img1)
        rec_img2_pil = T.ToPILImage()(rec_img2)
        rec_df_pil = T.ToPILImage()(rec_df)  # for the (currently disabled) deepfake output

        return (rec_img1_pil, rec_img2_pil, round(dfimgloss.item(), 3))
# Create the Gradio interface
interface = gr.Interface(
fn=gen_sources,
inputs=gr.Image(type="pil", label="Input Image"),
outputs=[
gr.Image(type="pil", label="Source Image 1"),
gr.Image(type="pil", label="Source Image 2"),
#gr.Image(type="pil", label="Deepfake Image"),
gr.Number(label="Reconstruction Loss")
],
    examples=["./df1.jpg", "./df2.jpg", "./df3.jpg", "./df4.jpg"],
    theme=gr.themes.Soft(),
    title="Uncovering Deepfake Image",
    description="Upload a suspected deepfake image to recover estimates of its two source images, along with the deepfake reconstruction loss.",
)
interface.launch()
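# launch() blocks and serves the web UI; when running locally (outside
# Hugging Face Spaces), interface.launch(share=True) also creates a
# temporary public link.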