|
import os |
|
import torch |
|
import torchvision.transforms as transforms |
|
import torch.utils.data |
|
import numpy as np |
|
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score |
|
import pickle |
|
from tqdm import tqdm |
|
from datetime import datetime |
|
from copy import deepcopy |
|
from dataset_paths import DATASET_PATHS |
|
import random |
|
|
|
from datasets import create_test_dataloader |
|
from utils.logger import create_logger |
|
import options |
|
from networks.validator import Validator |
|
|
|
|
|
# Default seed used by set_seed() when no explicit seed is given.
SEED = 0


def set_seed(seed=None):
    """Seed the PyTorch (CPU and CUDA), NumPy and Python ``random`` generators.

    Args:
        seed: integer seed. When omitted (``None``) the module-level ``SEED``
            is read at call time, so reassigning ``SEED`` keeps working as before.
    """
    if seed is None:
        seed = SEED
    torch.manual_seed(seed)
    # Explicitly seed every visible CUDA device, not just the current one.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
# Per-channel normalisation means, keyed by the statistics' source name.
MEAN = dict(
    imagenet=[0.485, 0.456, 0.406],
    clip=[0.48145466, 0.4578275, 0.40821073],
)

# Per-channel normalisation standard deviations, same keys as MEAN.
STD = dict(
    imagenet=[0.229, 0.224, 0.225],
    clip=[0.26862954, 0.26130258, 0.27577711],
)
|
|
|
|
|
|
|
def find_best_threshold(y_true, y_pred):
    """Return the score threshold that maximises accuracy on (y_true, y_pred).

    Labels are 0 for real and 1 for fake. A sample counts as fake when its
    score is strictly greater than the threshold, the same rule
    ``calculate_acc`` applies, so the accuracy found here is the accuracy that
    will be reported. Real and fake samples are selected through ``y_true``,
    so the inputs may be in any order.

    Args:
        y_true: 1-D array-like of 0/1 labels.
        y_pred: 1-D array-like of scores, same length as ``y_true``.

    Returns:
        The midpoint between the highest real score and the lowest fake score
        when every real score is strictly below every fake score. Otherwise,
        the value from ``y_pred`` with the highest accuracy; ties go to the
        candidate appearing last in ``y_pred``. Returns 0.5 for empty input.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_pred.size == 0:
        return 0.5

    real = np.sort(y_pred[y_true == 0])
    fake = np.sort(y_pred[y_true == 1])

    # Perfectly separable: any threshold in the gap is optimal; use its middle.
    if real.size and fake.size and real[-1] < fake[0]:
        return (real[-1] + fake[0]) / 2

    # For each candidate threshold t (every observed score), count correct
    # predictions under "fake iff score > t" using binary search on the
    # sorted per-class scores. This is O(n log n) and never binarises a copy.
    n_real_correct = np.searchsorted(real, y_pred, side="right")  # real <= t
    n_fake_correct = fake.size - np.searchsorted(fake, y_pred, side="right")  # fake > t
    acc = (n_real_correct + n_fake_correct) / y_pred.shape[0]

    # Last index attaining the maximum, mirroring an ">=" running-best scan.
    best_idx = acc.size - 1 - int(np.argmax(acc[::-1]))
    return y_pred[best_idx]
|
|
|
|
|
def calculate_acc(y_true, y_pred, thres):
    """Return (real accuracy, fake accuracy, overall accuracy) at ``thres``.

    A sample is predicted fake when its score is strictly greater than
    ``thres``. Real samples are those with ``y_true == 0`` and fake samples
    those with ``y_true == 1``. Both inputs are expected to be NumPy arrays so
    boolean masking works.
    """
    is_real = y_true == 0
    is_fake = y_true == 1
    predicted_fake = y_pred > thres

    real_acc = accuracy_score(y_true[is_real], predicted_fake[is_real])
    fake_acc = accuracy_score(y_true[is_fake], predicted_fake[is_fake])
    overall_acc = accuracy_score(y_true, predicted_fake)
    return real_acc, fake_acc, overall_acc
|
|
|
|
|
def validate(model, loader, logger, find_thres=False):
    """Run ``model`` over every batch of ``loader`` and compute detection metrics.

    Each batch ``data`` is passed to ``model.set_input``. The wrapped network
    ``model.model`` is applied to ``model.input`` and squashed with a sigmoid
    to obtain scores; labels are read from ``data[1]``.

    Args:
        model: validator wrapper exposing ``set_input``, ``input`` and ``model``.
        loader: sized iterable of batches; ``data[1]`` holds the labels
            (0 = real, 1 = fake, as used by ``calculate_acc``).
        logger: logger used for progress messages.
        find_thres: when True, also search for the accuracy-maximising
            threshold and report metrics at it.

    Returns:
        ``(ap, r_acc0, f_acc0, acc0)`` at the fixed 0.5 threshold. When
        ``find_thres`` is True, those four followed by
        ``(r_acc1, f_acc1, acc1, best_thres)`` at the searched threshold.
    """
    y_true, y_pred = [], []
    # Each iteration yields one batch, so len(loader) counts batches, not images.
    logger.info(f"Number of batches: {len(loader)}")

    with torch.no_grad():
        pbar = tqdm(loader)
        for data in pbar:
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.set_input(data)
            logits = model.model(model.input)
            # Assumes one logit per sample so the flattened scores line up
            # element-wise with the flattened labels in data[1].
            y_pred.extend(logits.flatten().sigmoid().tolist())
            y_true.extend(data[1].flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)

    ap = average_precision_score(y_true, y_pred)

    # Accuracies at the conventional 0.5 decision threshold.
    r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5)
    if not find_thres:
        return ap, r_acc0, f_acc0, acc0

    # Accuracies at the threshold that is optimal on this same data.
    best_thres = find_best_threshold(y_true, y_pred)
    r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres)

    return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres
|
|
|
|
|
|
|
|
|
|
|
def recursively_read(rootdir, must_contain, exts=("png", "jpg", "JPEG", "jpeg", "bmp")):
    """Return paths of files under ``rootdir`` with an allowed extension.

    A file is kept when the text after its LAST dot is in ``exts`` (exact,
    case-sensitive match) and its full joined path contains ``must_contain``.
    Files without any dot are skipped.

    Args:
        rootdir: directory walked recursively with ``os.walk``.
        must_contain: substring the full file path must contain ('' keeps all).
        exts: allowed extensions without the leading dot. Any container
            supporting ``in`` works; the default is an immutable tuple.

    Returns:
        List of matching paths in ``os.walk`` order.
    """
    out = []
    for root, _dirs, files in os.walk(rootdir):
        for name in files:
            # rpartition takes the final suffix, so "a.v2.png" -> "png" and
            # "a.png.txt" -> "txt"; a name with no dot yields an empty separator.
            _, dot, ext = name.rpartition(".")
            if not dot or ext not in exts:
                continue
            full_path = os.path.join(root, name)
            if must_contain in full_path:
                out.append(full_path)
    return out
|
|
|
|
|
def get_list(path, must_contain=''):
    """Return the image paths for ``path`` that contain ``must_contain``.

    If ``path`` contains ".pickle", it is unpickled and its items are filtered
    by substring. Otherwise, ``path`` is treated as a directory and scanned
    with ``recursively_read``.

    NOTE: ``pickle.load`` can execute arbitrary code; only pass pickle files
    from a trusted source.
    """
    if ".pickle" not in path:
        return recursively_read(path, must_contain)

    with open(path, 'rb') as handle:
        entries = pickle.load(handle)
    return [entry for entry in entries if must_contain in entry]
|
|
|
|
|
|
|
if __name__ == '__main__':
    val_opt = options.TestOptions().parse()

    # Seed the torch/NumPy/Python RNGs so repeated evaluations are reproducible.
    set_seed()

    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)
    logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    logger.info(f"working dir: {output_dir}")

    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    logger.info("ckpt loaded!")

    val_loader = create_test_dataloader(val_opt, clip_model=None, transform=model.clip_model.preprocess)

    ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(
        model, val_loader, logger, find_thres=True
    )

    summary = f"ap: {ap}, r_acc0: {r_acc0}, f_acc0: {f_acc0}, acc0:{acc0}, r_acc1: {r_acc1}, f_acc1: {f_acc1}, acc1: {acc1}, best_thres: {best_thres} "
    print(summary)
    logger.info(summary)

    # Write result files into output_dir, which was created above. A bare
    # val_opt.name directory is not guaranteed to exist, and open() would fail.
    with open(os.path.join(output_dir, 'ap.txt'), 'a') as f:
        f.write(str(round(ap * 100, 2)) + '\n')

    with open(os.path.join(output_dir, 'acc0.txt'), 'a') as f:
        f.write(
            str(round(r_acc0 * 100, 2)) + ' '
            + str(round(f_acc0 * 100, 2)) + ' '
            + str(round(acc0 * 100, 2)) + '\n'
        )
|
|
|
|