In [None]:
# Switch the kernel's working directory to the parent folder so that the
# relative paths used further down ("checkpoints/...", "imagenet_val", "fakes")
# are resolved from there.
# NOTE(review): presumably the parent folder is the repository root — confirm.
# Gotcha: this cell is not idempotent; re-running it moves up one more level.
%cd ..

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import PIL.Image
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from realfake.data import DictDataset, get_augs
from realfake.models import RealFakeClassifier, RealFakeParams
from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl

In [None]:
def load_from_checkpoint(checkpoint_dir, name=None, map_location="cpu"):
    """Rebuild a ``RealFakeClassifier`` from a saved checkpoint, ready for inference.

    Args:
        checkpoint_dir: Directory holding the checkpoint(s); str or ``Path``.
        name: File name of the checkpoint inside ``checkpoint_dir``. When
            ``None``, ``find_latest_checkpoint(checkpoint_dir)`` picks the file.
        map_location: Forwarded to ``torch.load`` to control where the stored
            tensors are placed. Defaults to ``"cpu"``.

    Returns:
        The classifier with the checkpoint's ``"state_dict"`` loaded and
        switched to evaluation mode via ``.eval()``.
    """
    root = Path(checkpoint_dir)

    if name is None:
        ckpt_path = find_latest_checkpoint(root)
    else:
        ckpt_path = root / name

    state = torch.load(ckpt_path, map_location=map_location)

    # The hyper-parameters are read from a "params.json" stored next to the
    # checkpoint file itself (which need not be directly in `checkpoint_dir`).
    hparams = RealFakeParams.parse_file(ckpt_path.parent / "params.json")
    # NOTE(review): presumably disabled because the weights are overwritten by
    # the checkpoint's state dict right below — confirm.
    hparams.pretrained = False

    clf = RealFakeClassifier(hparams)
    clf.load_state_dict(state["state_dict"])
    clf.eval()
    return clf

In [None]:
# Load the classifier from this run's checkpoint directory. No checkpoint name
# is given, so load_from_checkpoint falls back to find_latest_checkpoint, and
# the default map_location="cpu" is used.
model = load_from_checkpoint("checkpoints/convnext_large_2m_e5")

In [None]:
# Build one record per image: {"path": <file path as str>, "label": "real" | "fake"}.
# "real" records come from every entry directly inside imagenet_val/ (no
# extension filter); "fake" records come from every *.png found recursively
# under fakes/.
real = [{"path": str(img), "label": "real"} for img in Path("imagenet_val").iterdir()]
fake = [{"path": str(img), "label": "fake"} for img in Path("fakes").rglob("*.png")]

# Real records first, then fake ones.
data = [*real, *fake]
len(data)

In [None]:
batch_size = 128

# shuffle=False keeps the loader's output in the same order as `data`, which
# the positional label/path columns below rely on.
ds = DictDataset(data, get_augs(train=False))
dl = DataLoader(ds, batch_size=batch_size, num_workers=8, shuffle=False)

# Per-sample results accumulated as plain Python lists; the DataFrame is built
# once at the end.
fake_probs = []
matches = []

with torch.inference_mode():
    for batch in tqdm(dl):
        _, logits, y_true_onehot = model(batch)
        probs = logits.softmax(dim=1)
        y_true = y_true_onehot.argmax(dim=1)
        y_pred = probs.argmax(dim=1)
        matched = y_true == y_pred

        # Column 1 of the softmax output is recorded as P(fake).
        # NOTE(review): assumes class index 1 means "fake" — confirm against
        # the label encoding used by DictDataset / training.
        # tolist() converts the whole batch at once instead of one .item()
        # call per element.
        fake_probs += probs[:, 1].tolist()
        matches += matched.tolist()

# Labels and paths are attached by position, so there must be exactly one
# score per record in `data`.
if len(fake_probs) != len(data):
    raise ValueError(
        f"Got {len(fake_probs)} scores for {len(data)} records; "
        "cannot align scores with labels/paths."
    )

scores = pd.DataFrame({
    "fake_prob": fake_probs,
    "match": matches,
    "label": [r["label"] for r in data],
    "path": [r["path"] for r in data],
})

In [None]:
def view_results(df: pd.DataFrame,
                 query: str,
                 img_size: int = 256,
                 plot_size: int = 4,
                 n_rows: int = 5,
                 n_cols: int = 5):
    """Show a grid of images selected from ``df``, titled with their P(fake).

    Rows are sorted by ascending ``fake_prob``, re-indexed from 0, filtered
    with ``query``, and the first ``n_rows * n_cols`` matches are drawn in
    row-major order. Grid cells without an image are left blank.

    Args:
        df: Frame with at least a ``fake_prob`` column (float) and a ``path``
            column (image file path), plus any column referenced by ``query``.
        query: Expression passed to ``DataFrame.query``. It is evaluated after
            sorting and re-indexing, so ``index`` refers to the rank by
            ``fake_prob``.
        img_size: Each image is resized to ``img_size x img_size`` pixels
            (aspect ratio is not preserved).
        plot_size: Size of one grid cell, in inches.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.

    Returns:
        None. The figure is created through ``plt.subplots``.
    """
    # Only the rows that fit into the grid are kept, so no image is opened
    # that would never be displayed.
    selected = (df.sort_values(by="fake_prob")
                  .reset_index(drop=True)
                  .query(query)
                  .head(n_rows * n_cols))

    items = []
    for path, fake_prob in zip(selected["path"], selected["fake_prob"]):
        # resize() returns a new image, so it remains usable after the
        # underlying file has been closed by the context manager.
        with PIL.Image.open(path) as im:
            items.append((im.resize((img_size, img_size)), fake_prob))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so that
    # `axes.flat` is always available.
    f, axes = plt.subplots(n_rows, n_cols,
                           figsize=(n_cols*plot_size, n_rows*plot_size),
                           gridspec_kw={"hspace": 0.1, "wspace": 0},
                           squeeze=False)

    f.subplots_adjust(hspace=0, wspace=0)

    # Hide ticks/frames everywhere, including cells that receive no image.
    for ax in axes.flat:
        ax.set_axis_off()

    for ax, (im, score) in zip(axes.flat, items):
        ax.imshow(im)
        ax.set_title(f"P(fake)={score:2.2%}")
        ax.set_aspect("equal")

In [None]:
# Fake-labelled images with fake_prob >= 0.8. view_results sorts ascending by
# fake_prob and fills a 5x5 grid by default, so this shows the 25 lowest-scoring
# images within that range.
view_results(scores, "label == 'fake' and fake_prob >= 0.8")

In [None]:
# Fake-labelled images with fake_prob below 0.5, lowest fake_prob first
# (view_results sorts ascending and shows at most a 5x5 grid by default).
view_results(scores, "label == 'fake' and fake_prob < 0.5")