{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c3e1a9f4-7b2d-4e58-9a61-0d4f2b8c7e15",
   "metadata": {},
   "source": [
    "# Real vs. fake classifier: held-out evaluation\n",
    "\n",
    "Loads a trained `RealFakeClassifier` checkpoint and measures its accuracy on a balanced set: every PNG found (recursively) under `fakes/`, labelled `fake`, plus an equal number of ImageNet-1k validation images sampled without replacement, labelled `real`.\n",
    "\n",
    "**Note:** the stored outputs come from an earlier, unseeded run that sampled the real images with replacement; re-run the notebook (Restart & Run All) to refresh them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd1758c9-040d-4727-96d8-951d385ba277",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/admin/home-devforfu/realfake\n"
     ]
    }
   ],
   "source": [
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "from realfake.data import DictDataset, get_augs\n",
    "from realfake.models import RealFakeClassifier, RealFakeParams\n",
    "from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d2f6b1a-4c3e-4f7a-b9d0-5e1c2a7f3b64",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "BATCH_SIZE = 32  # matches the recorded run below: 157 batches for 5008 images\n",
    "NUM_WORKERS = 8\n",
    "CHECKPOINT_DIR = \"checkpoints/convnext_large_2m_e5\"\n",
    "FAKES_DIR = Path(\"fakes\")\n",
    "IMAGENET_VAL_DIR = Path(f\"/fsx/{get_user_name()}/data/imagenet-1k/validation\")\n",
    "\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def load_from_checkpoint(checkpoint_dir, name=None, map_location=\"cpu\"):\n",
    "    \"\"\"Rebuild a `RealFakeClassifier` from a saved checkpoint and put it in eval mode.\n",
    "\n",
    "    Args:\n",
    "        checkpoint_dir: directory containing the checkpoint file(s).\n",
    "        name: checkpoint file name inside `checkpoint_dir`; when None, the path\n",
    "            returned by `find_latest_checkpoint(checkpoint_dir)` is used.\n",
    "        map_location: forwarded to `torch.load` to remap tensor storage (default: CPU).\n",
    "\n",
    "    Returns:\n",
    "        The classifier with the checkpoint's `state_dict` loaded and `.eval()` applied.\n",
    "\n",
    "    Model hyper-parameters are read from `params.json` located in the same\n",
    "    directory as the selected checkpoint file.\n",
    "    \"\"\"\n",
    "    checkpoint_dir = Path(checkpoint_dir)\n",
    "    path = Path(find_latest_checkpoint(checkpoint_dir)) if name is None else checkpoint_dir / name\n",
    "    checkpoint = torch.load(path, map_location=map_location)\n",
    "    params = RealFakeParams.parse_file(path.parent / \"params.json\")\n",
    "    # The weights come from the checkpoint; presumably this stops the constructor\n",
    "    # from fetching pretrained backbone weights (verify in RealFakeClassifier).\n",
    "    params.pretrained = False\n",
    "    classifier = RealFakeClassifier(params)\n",
    "    classifier.load_state_dict(checkpoint[\"state_dict\"])\n",
    "    classifier.eval()\n",
    "    return classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = load_from_checkpoint(CHECKPOINT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "fb594a6e-0f75-4d06-a65e-68dda1d39ef6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2504"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sorted so the record order does not depend on filesystem listing order.\n",
    "fake = [{\"path\": str(fn), \"label\": \"fake\"} for fn in sorted(FAKES_DIR.glob(\"**/*.png\"))]\n",
    "len(fake)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "da8ae077-53d7-494a-b330-9d674c257734",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sorted so that the seeded sample below is reproducible across machines.\n",
    "imagenet_validation = sorted(IMAGENET_VAL_DIR.glob(\"*.JPEG\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "59c21450-c7bb-479e-8d07-5861dca3dac7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sample WITHOUT replacement so the real half contains no duplicate images;\n",
    "# random.sample raises ValueError if there are fewer validation images than fakes.\n",
    "real = [{\"path\": str(fn), \"label\": \"real\"} for fn in random.sample(imagenet_validation, k=len(fake))]\n",
    "assert len(real) == len(fake)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "1d445977-e1f1-40c2-8975-1eb1f7c04493",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "records = fake + real\n",
    "random.shuffle(records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [03:45<00:00, 1.44s/it]\n"
     ]
    }
   ],
   "source": [
    "dataset = DictDataset(records, get_augs(train=False))\n",
    "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)\n",
    "\n",
    "matched, total = 0, len(dataset)\n",
    "\n",
    "with torch.inference_mode():\n",
    "    for batch in tqdm(loader):\n",
    "        _, logits, y_true_onehot = model(batch)\n",
    "        y_true = y_true_onehot.argmax(dim=1)\n",
    "        # softmax is monotonic, so argmax of the raw logits gives the same prediction\n",
    "        y_pred = logits.argmax(dim=1)\n",
    "        matched += (y_pred == y_true).sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "6077a3f7-4ad5-4fc0-bc53-7e939b64a658",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 99.58%\n"
     ]
    }
   ],
   "source": [
    "assert total > 0, \"no records were evaluated\"\n",
    "print(f\"Accuracy: {matched / total:.2%}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}