In [1]:
# Move the working directory up one level; later cells use paths relative to it
# (e.g. "checkpoints/..." and "fakes").
%cd ..

/admin/home-devforfu/realfake


In [98]:
import random
from pathlib import Path
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from realfake.data import DictDataset, get_augs
from realfake.models import RealFakeClassifier, RealFakeParams
from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl

In [68]:
def load_from_checkpoint(checkpoint_dir, name=None, map_location="cpu"):
    """Rebuild a RealFakeClassifier from a saved checkpoint and put it in eval mode.

    Args:
        checkpoint_dir: Directory that holds the checkpoint files.
        name: File name of the checkpoint inside ``checkpoint_dir``. When ``None``,
            the path returned by ``find_latest_checkpoint(checkpoint_dir)`` is used.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        A ``RealFakeClassifier`` with the checkpoint's ``"state_dict"`` entry loaded,
        switched to evaluation mode.
    """
    root = Path(checkpoint_dir)
    if name is None:
        ckpt_path = find_latest_checkpoint(root)
    else:
        ckpt_path = root / name

    saved = torch.load(ckpt_path, map_location=map_location)

    # The model hyper-parameters are read from params.json next to the checkpoint file.
    hparams = RealFakeParams.parse_file(ckpt_path.parent / "params.json")
    # NOTE(review): presumably avoids fetching pretrained weights, since the
    # state dict loaded below replaces them anyway -- confirm in RealFakeClassifier.
    hparams.pretrained = False

    net = RealFakeClassifier(hparams)
    net.load_state_dict(saved["state_dict"])
    net.eval()
    return net

In [69]:
# No checkpoint name is given, so load_from_checkpoint falls back to
# find_latest_checkpoint for this directory.
model = load_from_checkpoint(checkpoint_dir="checkpoints/convnext_large_2m_e5")

In [109]:
# One record per PNG found anywhere under fakes/, each labelled "fake".
fake = [
    {"path": str(png_path), "label": "fake"}
    for png_path in Path("fakes").glob("**/*.png")
]
len(fake)

2504

In [110]:
# Paths of all *.JPEG files directly inside the current user's ImageNet validation folder.
imagenet_validation = [
    *Path(f"/fsx/{get_user_name()}/data/imagenet-1k/validation").glob("*.JPEG")
]

In [111]:
# Draw as many real images as there are fake ones so both classes have the same size.
# The count must come from `fake` (the list built above); the name `fakes` is not
# defined anywhere in this notebook and raised NameError on a fresh kernel.
# NOTE: random.choices samples WITH replacement, so the same file can be picked more than once.
real = [
    {"path": str(fn), "label": "real"}
    for fn in random.choices(imagenet_validation, k=len(fake))
]

In [113]:
# Single evaluation list: all fake records followed by all real records.
records = [*fake, *real]

In [117]:
# Shuffle in place so fake and real records are interleaved instead of grouped by label.
# NOTE(review): no random seed is set in the visible cells, so the order (and the
# sampled real images) differ between runs; re-running this cell reshuffles again.
random.shuffle(records)

In [120]:
# Evaluation batch size, now actually passed to the DataLoader (it used to be
# declared here while the loader hard-coded its own value). 32 is the value the
# recorded run used: 157 batches for 2 x 2504 records.
batch_size = 32

with torch.inference_mode():
    ds = DictDataset(records, get_augs(train=False))
    dl = DataLoader(ds, batch_size=batch_size, num_workers=8, shuffle=False)

    # matched counts correct predictions; total is the number of evaluated samples.
    matched, total = 0, len(ds)

    for batch in tqdm(dl):
        # The model call returns a 3-tuple; only the logits and the targets
        # (named y_true_onehot, presumably one-hot encoded) are used here.
        _, logits, y_true_onehot = model(batch)
        y_true = y_true_onehot.argmax(dim=1)
        y_pred = logits.softmax(dim=1).argmax(dim=1)
        equals = y_true == y_pred
        matched += equals.sum().item()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [03:45<00:00, 1.44s/it]


In [121]:
# Share of samples whose predicted class index equals the target class index,
# formatted as a percentage with two decimals.
print(f"Accuracy: {matched/total:2.2%}")

Accuracy: 99.58%
