{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a6e2f3c-1b4d-4c8e-9f2a-5d7b8c9e0f11",
   "metadata": {},
   "source": [
    "# Inspect real-vs-fake classifier predictions\n",
    "\n",
    "Load a trained `RealFakeClassifier` checkpoint and score every file in `imagenet_val/` (labelled *real*) and every PNG under `fakes/` (labelled *fake*). Then browse the generated images the model confidently flags as fake, and those it scores below 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd1758c9-040d-4727-96d8-951d385ba277",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# All paths below are relative to the parent of the notebook's directory.\n",
    "# Remember that directory on the first run so re-running this cell does not\n",
    "# keep walking up the directory tree (a bare `%cd ..` would).\n",
    "PROJECT_ROOT = globals().get(\"PROJECT_ROOT\") or os.path.dirname(os.getcwd())\n",
    "os.chdir(PROJECT_ROOT)\n",
    "PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import PIL.Image\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from realfake.data import DictDataset, get_augs\n",
    "from realfake.models import RealFakeClassifier, RealFakeParams\n",
    "from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7f3a4d-2c5e-4d9f-8a3b-6e8c9d0f1a22",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "CHECKPOINT_DIR = Path(\"checkpoints/convnext_large_2m_e5\")\n",
    "REAL_DIR = Path(\"imagenet_val\")  # every file directly inside is labelled \"real\"\n",
    "FAKE_DIR = Path(\"fakes\")  # every *.png below it (recursively) is labelled \"fake\"\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "NUM_WORKERS = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c8a4b5e-3d6f-4e0a-9b4c-7f9d0e1a2b33",
   "metadata": {},
   "source": [
    "## Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_from_checkpoint(checkpoint_dir, name=None, map_location=\"cpu\"):\n",
    "    \"\"\"Rebuild a `RealFakeClassifier` from a saved checkpoint and put it in eval mode.\n",
    "\n",
    "    Args:\n",
    "        checkpoint_dir: Directory containing the checkpoint file.\n",
    "        name: Checkpoint file name inside `checkpoint_dir`. If None, the file\n",
    "            returned by `find_latest_checkpoint(checkpoint_dir)` is used.\n",
    "        map_location: Forwarded to `torch.load` (defaults to loading on CPU).\n",
    "\n",
    "    Returns:\n",
    "        The classifier with the checkpoint's `state_dict` loaded and `.eval()` called.\n",
    "        Hyper-parameters are read from `params.json` next to the checkpoint file.\n",
    "    \"\"\"\n",
    "    checkpoint_dir = Path(checkpoint_dir)\n",
    "    path = find_latest_checkpoint(checkpoint_dir) if name is None else checkpoint_dir / name\n",
    "    # torch.load unpickles the file: only load checkpoints from trusted sources.\n",
    "    checkpoint = torch.load(path, map_location=map_location)\n",
    "    params = RealFakeParams.parse_file(path.parent / \"params.json\")\n",
    "    # Weights come from the checkpoint below; presumably this skips fetching\n",
    "    # pretrained backbone weights -- TODO confirm in realfake.models.\n",
    "    params.pretrained = False\n",
    "    classifier = RealFakeClassifier(params)\n",
    "    classifier.load_state_dict(checkpoint[\"state_dict\"])\n",
    "    classifier.eval()\n",
    "    return classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = load_from_checkpoint(CHECKPOINT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d9b5c6f-4e7a-4f1b-8c5d-8a0e1f2b3c44",
   "metadata": {},
   "source": [
    "## Build the evaluation set\n",
    "\n",
    "Paths are sorted because `iterdir` and `glob` yield files in arbitrary order; sorting makes the row order (and the plots below) reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3445075-07ba-424c-9849-26c9de6ce1a8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "real = [{\"path\": str(p), \"label\": \"real\"} for p in sorted(REAL_DIR.iterdir()) if p.is_file()]\n",
    "fake = [{\"path\": str(p), \"label\": \"fake\"} for p in sorted(FAKE_DIR.glob(\"**/*.png\"))]\n",
    "data = real + fake\n",
    "\n",
    "assert real and fake, \"expected at least one real and one fake image\"\n",
    "len(real), len(fake)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e0c6d7a-5f8b-4a2c-9d6e-9b1f2a3c4d55",
   "metadata": {},
   "source": [
    "## Score every image\n",
    "\n",
    "Images are scored in order (`shuffle=False`), so row *i* of `scores` corresponds to `data[i]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset = DictDataset(data, get_augs(train=False))\n",
    "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)\n",
    "\n",
    "rows = []\n",
    "with torch.inference_mode():\n",
    "    for batch in tqdm(loader, desc=\"scoring\"):\n",
    "        _, logits, y_true_onehot = model(batch)\n",
    "        probs = logits.softmax(dim=1)\n",
    "        y_true = y_true_onehot.argmax(dim=1)\n",
    "        y_pred = probs.argmax(dim=1)\n",
    "        # Column 1 of `probs` is read as P(fake); assumes the dataset encodes\n",
    "        # \"fake\" as class 1 -- TODO confirm against realfake.data.\n",
    "        # .tolist() converts the whole batch at once instead of one .item() per image.\n",
    "        fake_probs = probs[:, 1].tolist()\n",
    "        matches = (y_true == y_pred).tolist()\n",
    "        rows += [{\"fake_prob\": p, \"match\": m} for p, m in zip(fake_probs, matches)]\n",
    "\n",
    "# Predictions are joined to `data` by position, so the counts must agree.\n",
    "assert len(rows) == len(data), f\"{len(rows)} predictions for {len(data)} images\"\n",
    "scores = pd.DataFrame(rows)\n",
    "scores[\"label\"] = [r[\"label\"] for r in data]\n",
    "scores[\"path\"] = [r[\"path\"] for r in data]\n",
    "scores.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a2e8f9c-7b0d-4c4e-9f8a-1d3b4c5e6f77",
   "metadata": {},
   "source": [
    "Fraction of images whose predicted class matches the true class, per label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1d7e8b-6a9c-4b3d-8e7f-0c2a3b4d5e66",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scores.groupby(\"label\")[\"match\"].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d5b1c2f-0e3a-4f7b-8c1d-4a6e7f8b9c00",
   "metadata": {},
   "source": [
    "## Browse predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c358c74-845d-42f1-9fcb-673e2a90ef69",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def view_results(df: pd.DataFrame,\n",
    "                 query: str,\n",
    "                 img_size: int = 256,\n",
    "                 plot_size: int = 4,\n",
    "                 n_rows: int = 5,\n",
    "                 n_cols: int = 5) -> plt.Figure:\n",
    "    \"\"\"Plot a grid of the images in `df` that satisfy `query`, lowest P(fake) first.\n",
    "\n",
    "    Args:\n",
    "        df: Frame with at least `path` and `fake_prob` columns (e.g. `scores`).\n",
    "        query: Row filter passed to `DataFrame.query`.\n",
    "        img_size: Each image is resized to `img_size` x `img_size` pixels.\n",
    "        plot_size: Width and height of each grid cell, in inches.\n",
    "        n_rows, n_cols: Grid shape; at most `n_rows * n_cols` images are shown.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure.\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` works.\n",
    "    fig, axes = plt.subplots(n_rows, n_cols,\n",
    "                             figsize=(n_cols * plot_size, n_rows * plot_size),\n",
    "                             gridspec_kw={\"hspace\": 0.1, \"wspace\": 0},\n",
    "                             squeeze=False)\n",
    "\n",
    "    # Only the first n_rows * n_cols matches are displayed, so only open those.\n",
    "    selected = df.query(query).sort_values(by=\"fake_prob\").head(n_rows * n_cols)\n",
    "\n",
    "    for ax, rec in zip(axes.flat, selected.itertuples()):\n",
    "        # The context manager closes the file; convert(\"RGB\") returns a loaded copy\n",
    "        # and normalises grayscale/CMYK/palette images so imshow shows true colours.\n",
    "        with PIL.Image.open(rec.path) as im:\n",
    "            thumb = im.convert(\"RGB\").resize((img_size, img_size))\n",
    "        ax.imshow(thumb)\n",
    "        ax.set_title(f\"P(fake)={rec.fake_prob:2.2%}\")\n",
    "\n",
    "    # Hide every axis, including the unused cells when there are fewer matches.\n",
    "    for ax in axes.flat:\n",
    "        ax.set_axis_off()\n",
    "        ax.set_aspect(\"equal\")\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b3f9a0d-8c1e-4d5f-8a9b-2e4c5d6f7a88",
   "metadata": {},
   "source": [
    "Generated images the model scores at P(fake) of at least 0.8:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeb284f7-46a3-408c-afa4-158ba9640571",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "view_results(scores, \"label == 'fake' and fake_prob >= 0.8\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c4a0b1e-9d2f-4e6a-9b0c-3f5d6e7a8b99",
   "metadata": {},
   "source": [
    "Generated images the model scores below P(fake) = 0.5:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f518ea74-3c97-4bb9-a461-73c108dac75f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "view_results(scores, \"label == 'fake' and fake_prob < 0.5\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}