devforfu committed on
Commit
7447212
1 Parent(s): 9c92cbf

Prediction confidence analysis

Browse files
Files changed (1) hide show
  1. nbs/inference.ipynb +77 -96
nbs/inference.ipynb CHANGED
@@ -2,38 +2,34 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "fd1758c9-040d-4727-96d8-951d385ba277",
7
  "metadata": {
8
  "tags": []
9
  },
10
- "outputs": [
11
- {
12
- "name": "stdout",
13
- "output_type": "stream",
14
- "text": [
15
- "/admin/home-devforfu/realfake\n"
16
- ]
17
- }
18
- ],
19
  "source": [
20
  "%cd .."
21
  ]
22
  },
23
  {
24
  "cell_type": "code",
25
- "execution_count": 98,
26
  "id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
27
  "metadata": {
28
  "tags": []
29
  },
30
  "outputs": [],
31
  "source": [
32
- "import random\n",
33
  "from pathlib import Path\n",
 
 
 
 
34
  "import torch\n",
35
  "from tqdm import tqdm\n",
36
  "from torch.utils.data import DataLoader\n",
 
37
  "from realfake.data import DictDataset, get_augs\n",
38
  "from realfake.models import RealFakeClassifier, RealFakeParams\n",
39
  "from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
@@ -41,7 +37,7 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": 68,
45
  "id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
46
  "metadata": {
47
  "tags": []
@@ -62,7 +58,7 @@
62
  },
63
  {
64
  "cell_type": "code",
65
- "execution_count": 69,
66
  "id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
67
  "metadata": {
68
  "tags": []
@@ -74,128 +70,113 @@
74
  },
75
  {
76
  "cell_type": "code",
77
- "execution_count": 109,
78
- "id": "fb594a6e-0f75-4d06-a65e-68dda1d39ef6",
79
- "metadata": {
80
- "tags": []
81
- },
82
- "outputs": [
83
- {
84
- "data": {
85
- "text/plain": [
86
- "2504"
87
- ]
88
- },
89
- "execution_count": 109,
90
- "metadata": {},
91
- "output_type": "execute_result"
92
- }
93
- ],
94
- "source": [
95
- "fake = [{\"path\": str(fn), \"label\": \"fake\"} for fn in Path(\"fakes\").glob(\"**/*.png\")]\n",
96
- "len(fake)"
97
- ]
98
- },
99
- {
100
- "cell_type": "code",
101
- "execution_count": 110,
102
- "id": "da8ae077-53d7-494a-b330-9d674c257734",
103
  "metadata": {
104
  "tags": []
105
  },
106
  "outputs": [],
107
  "source": [
108
- "imagenet_validation = list(Path(f\"/fsx/{get_user_name()}/data/imagenet-1k/validation\").glob(\"*.JPEG\"))"
 
 
 
109
  ]
110
  },
111
  {
112
  "cell_type": "code",
113
- "execution_count": 111,
114
- "id": "59c21450-c7bb-479e-8d07-5861dca3dac7",
115
  "metadata": {
116
  "tags": []
117
  },
118
  "outputs": [],
119
  "source": [
120
- "real = [{\"path\": str(fn), \"label\": \"real\"} for fn in random.choices(imagenet_validation, k=len(fakes))]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  ]
122
  },
123
  {
124
  "cell_type": "code",
125
- "execution_count": 113,
126
- "id": "1d445977-e1f1-40c2-8975-1eb1f7c04493",
127
  "metadata": {
128
  "tags": []
129
  },
130
  "outputs": [],
131
  "source": [
132
- "records = fake + real"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  ]
134
  },
135
  {
136
  "cell_type": "code",
137
- "execution_count": 117,
138
- "id": "3a25ded0-c284-4148-b29a-25897cce9c5b",
139
  "metadata": {
140
  "tags": []
141
  },
142
  "outputs": [],
143
  "source": [
144
- "random.shuffle(records)"
145
- ]
146
- },
147
- {
148
- "cell_type": "code",
149
- "execution_count": 120,
150
- "id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
151
- "metadata": {
152
- "tags": []
153
- },
154
- "outputs": [
155
- {
156
- "name": "stderr",
157
- "output_type": "stream",
158
- "text": [
159
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [03:45<00:00, 1.44s/it]\n"
160
- ]
161
- }
162
- ],
163
- "source": [
164
- "batch_size = 128\n",
165
- " \n",
166
- "with torch.inference_mode():\n",
167
- " ds = DictDataset(records, get_augs(train=False))\n",
168
- " dl = DataLoader(ds, batch_size=32, num_workers=8, shuffle=False)\n",
169
- "\n",
170
- " matched, total = 0, len(ds)\n",
171
- "\n",
172
- " for batch in tqdm(dl):\n",
173
- " _, logits, y_true_onehot = model(batch)\n",
174
- " y_true = y_true_onehot.argmax(dim=1)\n",
175
- " y_pred = logits.softmax(dim=1).argmax(dim=1)\n",
176
- " equals = y_true == y_pred\n",
177
- " # print(equals.float().mean())\n",
178
- " matched += equals.sum().item()"
179
  ]
180
  },
181
  {
182
  "cell_type": "code",
183
- "execution_count": 121,
184
- "id": "6077a3f7-4ad5-4fc0-bc53-7e939b64a658",
185
  "metadata": {
186
  "tags": []
187
  },
188
- "outputs": [
189
- {
190
- "name": "stdout",
191
- "output_type": "stream",
192
- "text": [
193
- "Accuracy: 99.58%\n"
194
- ]
195
- }
196
- ],
197
  "source": [
198
- "print(f\"Accuracy: {matched/total:2.2%}\")"
199
  ]
200
  }
201
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "fd1758c9-040d-4727-96d8-951d385ba277",
7
  "metadata": {
8
  "tags": []
9
  },
10
+ "outputs": [],
 
 
 
 
 
 
 
 
11
  "source": [
12
  "%cd .."
13
  ]
14
  },
15
  {
16
  "cell_type": "code",
17
+ "execution_count": null,
18
  "id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
19
  "metadata": {
20
  "tags": []
21
  },
22
  "outputs": [],
23
  "source": [
 
24
  "from pathlib import Path\n",
25
+ "\n",
26
+ "import matplotlib.pyplot as plt\n",
27
+ "import pandas as pd\n",
28
+ "import PIL.Image\n",
29
  "import torch\n",
30
  "from tqdm import tqdm\n",
31
  "from torch.utils.data import DataLoader\n",
32
+ "\n",
33
  "from realfake.data import DictDataset, get_augs\n",
34
  "from realfake.models import RealFakeClassifier, RealFakeParams\n",
35
  "from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
 
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": null,
41
  "id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
42
  "metadata": {
43
  "tags": []
 
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": null,
62
  "id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
63
  "metadata": {
64
  "tags": []
 
70
  },
71
  {
72
  "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "d3445075-07ba-424c-9849-26c9de6ce1a8",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  "metadata": {
76
  "tags": []
77
  },
78
  "outputs": [],
79
  "source": [
80
+ "real = [{\"path\": str(p), \"label\": \"real\"} for p in Path(\"imagenet_val\").iterdir()]\n",
81
+ "fake = [{\"path\": str(p), \"label\": \"fake\"} for p in Path(\"fakes\").glob(\"**/*.png\")]\n",
82
+ "data = real + fake\n",
83
+ "len(data)"
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
90
  "metadata": {
91
  "tags": []
92
  },
93
  "outputs": [],
94
  "source": [
95
+ "batch_size = 128\n",
96
+ "scores = []\n",
97
+ "\n",
98
+ "with torch.inference_mode():\n",
99
+ " ds = DictDataset(data, get_augs(train=False))\n",
100
+ " dl = DataLoader(ds, batch_size=batch_size, num_workers=8, shuffle=False)\n",
101
+ "\n",
102
+ " for batch in tqdm(dl):\n",
103
+ " _, logits, y_true_onehot = model(batch)\n",
104
+ " probs = logits.softmax(dim=1)\n",
105
+ " y_true = y_true_onehot.argmax(dim=1)\n",
106
+ " y_pred = probs.argmax(dim=1)\n",
107
+ " matched = y_true == y_pred\n",
108
+ " \n",
109
+ " scores += [\n",
110
+ " {\"fake_prob\": fake_prob.item(), \"match\": match.item()}\n",
111
+ " for fake_prob, match in zip(probs[:, 1], matched)\n",
112
+ " ]\n",
113
+ " \n",
114
+ "scores = pd.DataFrame(scores)\n",
115
+ "scores[\"label\"] = [r[\"label\"] for r in data]\n",
116
+ "scores[\"path\"] = [r[\"path\"] for r in data]"
117
  ]
118
  },
119
  {
120
  "cell_type": "code",
121
+ "execution_count": null,
122
+ "id": "9c358c74-845d-42f1-9fcb-673e2a90ef69",
123
  "metadata": {
124
  "tags": []
125
  },
126
  "outputs": [],
127
  "source": [
128
+ "def view_results(df: pd.DataFrame, \n",
129
+ " query: str, \n",
130
+ " img_size: int = 256, \n",
131
+ " plot_size: int = 4,\n",
132
+ " n_rows: int = 5,\n",
133
+ " n_cols: int = 5):\n",
134
+ " \n",
135
+ " f, axes = plt.subplots(n_rows, n_cols, \n",
136
+ " figsize=(n_cols*plot_size, n_rows*plot_size), \n",
137
+ " gridspec_kw={\"hspace\": 0.1, \"wspace\": 0})\n",
138
+ " \n",
139
+ " f.subplots_adjust(hspace=0, wspace=0)\n",
140
+ " \n",
141
+ " sz = img_size\n",
142
+ " \n",
143
+ " items = (df.sort_values(by=\"fake_prob\")\n",
144
+ " .reset_index(drop=True)\n",
145
+ " .query(query)\n",
146
+ " .apply(lambda rec: (\n",
147
+ " PIL.Image.open(rec.path).resize((sz,sz)), \n",
148
+ " rec.fake_prob), axis=1)\n",
149
+ " .path.tolist())\n",
150
+ "\n",
151
+ " for ax, (im, score) in zip(axes.flat, items):\n",
152
+ " ax.imshow(im)\n",
153
+ " ax.set_title(f\"P(fake)={score:2.2%}\")\n",
154
+ " ax.set_axis_off()\n",
155
+ " ax.set_aspect(\"equal\")"
156
  ]
157
  },
158
  {
159
  "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "aeb284f7-46a3-408c-afa4-158ba9640571",
162
  "metadata": {
163
  "tags": []
164
  },
165
  "outputs": [],
166
  "source": [
167
+ "view_results(scores, \"label == 'fake' and fake_prob >= 0.8\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  ]
169
  },
170
  {
171
  "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "f518ea74-3c97-4bb9-a461-73c108dac75f",
174
  "metadata": {
175
  "tags": []
176
  },
177
+ "outputs": [],
 
 
 
 
 
 
 
 
178
  "source": [
179
+ "view_results(scores, \"label == 'fake' and fake_prob < 0.5\")"
180
  ]
181
  }
182
  ],