devforfu
commited on
Commit
•
7447212
1
Parent(s):
9c92cbf
Prediction confidence analysis
Browse files- nbs/inference.ipynb +77 -96
nbs/inference.ipynb
CHANGED
@@ -2,38 +2,34 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "fd1758c9-040d-4727-96d8-951d385ba277",
|
7 |
"metadata": {
|
8 |
"tags": []
|
9 |
},
|
10 |
-
"outputs": [
|
11 |
-
{
|
12 |
-
"name": "stdout",
|
13 |
-
"output_type": "stream",
|
14 |
-
"text": [
|
15 |
-
"/admin/home-devforfu/realfake\n"
|
16 |
-
]
|
17 |
-
}
|
18 |
-
],
|
19 |
"source": [
|
20 |
"%cd .."
|
21 |
]
|
22 |
},
|
23 |
{
|
24 |
"cell_type": "code",
|
25 |
-
"execution_count":
|
26 |
"id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
|
27 |
"metadata": {
|
28 |
"tags": []
|
29 |
},
|
30 |
"outputs": [],
|
31 |
"source": [
|
32 |
-
"import random\n",
|
33 |
"from pathlib import Path\n",
|
|
|
|
|
|
|
|
|
34 |
"import torch\n",
|
35 |
"from tqdm import tqdm\n",
|
36 |
"from torch.utils.data import DataLoader\n",
|
|
|
37 |
"from realfake.data import DictDataset, get_augs\n",
|
38 |
"from realfake.models import RealFakeClassifier, RealFakeParams\n",
|
39 |
"from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
|
@@ -41,7 +37,7 @@
|
|
41 |
},
|
42 |
{
|
43 |
"cell_type": "code",
|
44 |
-
"execution_count":
|
45 |
"id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
|
46 |
"metadata": {
|
47 |
"tags": []
|
@@ -62,7 +58,7 @@
|
|
62 |
},
|
63 |
{
|
64 |
"cell_type": "code",
|
65 |
-
"execution_count":
|
66 |
"id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
|
67 |
"metadata": {
|
68 |
"tags": []
|
@@ -74,128 +70,113 @@
|
|
74 |
},
|
75 |
{
|
76 |
"cell_type": "code",
|
77 |
-
"execution_count":
|
78 |
-
"id": "
|
79 |
-
"metadata": {
|
80 |
-
"tags": []
|
81 |
-
},
|
82 |
-
"outputs": [
|
83 |
-
{
|
84 |
-
"data": {
|
85 |
-
"text/plain": [
|
86 |
-
"2504"
|
87 |
-
]
|
88 |
-
},
|
89 |
-
"execution_count": 109,
|
90 |
-
"metadata": {},
|
91 |
-
"output_type": "execute_result"
|
92 |
-
}
|
93 |
-
],
|
94 |
-
"source": [
|
95 |
-
"fake = [{\"path\": str(fn), \"label\": \"fake\"} for fn in Path(\"fakes\").glob(\"**/*.png\")]\n",
|
96 |
-
"len(fake)"
|
97 |
-
]
|
98 |
-
},
|
99 |
-
{
|
100 |
-
"cell_type": "code",
|
101 |
-
"execution_count": 110,
|
102 |
-
"id": "da8ae077-53d7-494a-b330-9d674c257734",
|
103 |
"metadata": {
|
104 |
"tags": []
|
105 |
},
|
106 |
"outputs": [],
|
107 |
"source": [
|
108 |
-
"
|
|
|
|
|
|
|
109 |
]
|
110 |
},
|
111 |
{
|
112 |
"cell_type": "code",
|
113 |
-
"execution_count":
|
114 |
-
"id": "
|
115 |
"metadata": {
|
116 |
"tags": []
|
117 |
},
|
118 |
"outputs": [],
|
119 |
"source": [
|
120 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
]
|
122 |
},
|
123 |
{
|
124 |
"cell_type": "code",
|
125 |
-
"execution_count":
|
126 |
-
"id": "
|
127 |
"metadata": {
|
128 |
"tags": []
|
129 |
},
|
130 |
"outputs": [],
|
131 |
"source": [
|
132 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
]
|
134 |
},
|
135 |
{
|
136 |
"cell_type": "code",
|
137 |
-
"execution_count":
|
138 |
-
"id": "
|
139 |
"metadata": {
|
140 |
"tags": []
|
141 |
},
|
142 |
"outputs": [],
|
143 |
"source": [
|
144 |
-
"
|
145 |
-
]
|
146 |
-
},
|
147 |
-
{
|
148 |
-
"cell_type": "code",
|
149 |
-
"execution_count": 120,
|
150 |
-
"id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
|
151 |
-
"metadata": {
|
152 |
-
"tags": []
|
153 |
-
},
|
154 |
-
"outputs": [
|
155 |
-
{
|
156 |
-
"name": "stderr",
|
157 |
-
"output_type": "stream",
|
158 |
-
"text": [
|
159 |
-
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [03:45<00:00, 1.44s/it]\n"
|
160 |
-
]
|
161 |
-
}
|
162 |
-
],
|
163 |
-
"source": [
|
164 |
-
"batch_size = 128\n",
|
165 |
-
" \n",
|
166 |
-
"with torch.inference_mode():\n",
|
167 |
-
" ds = DictDataset(records, get_augs(train=False))\n",
|
168 |
-
" dl = DataLoader(ds, batch_size=32, num_workers=8, shuffle=False)\n",
|
169 |
-
"\n",
|
170 |
-
" matched, total = 0, len(ds)\n",
|
171 |
-
"\n",
|
172 |
-
" for batch in tqdm(dl):\n",
|
173 |
-
" _, logits, y_true_onehot = model(batch)\n",
|
174 |
-
" y_true = y_true_onehot.argmax(dim=1)\n",
|
175 |
-
" y_pred = logits.softmax(dim=1).argmax(dim=1)\n",
|
176 |
-
" equals = y_true == y_pred\n",
|
177 |
-
" # print(equals.float().mean())\n",
|
178 |
-
" matched += equals.sum().item()"
|
179 |
]
|
180 |
},
|
181 |
{
|
182 |
"cell_type": "code",
|
183 |
-
"execution_count":
|
184 |
-
"id": "
|
185 |
"metadata": {
|
186 |
"tags": []
|
187 |
},
|
188 |
-
"outputs": [
|
189 |
-
{
|
190 |
-
"name": "stdout",
|
191 |
-
"output_type": "stream",
|
192 |
-
"text": [
|
193 |
-
"Accuracy: 99.58%\n"
|
194 |
-
]
|
195 |
-
}
|
196 |
-
],
|
197 |
"source": [
|
198 |
-
"
|
199 |
]
|
200 |
}
|
201 |
],
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"id": "fd1758c9-040d-4727-96d8-951d385ba277",
|
7 |
"metadata": {
|
8 |
"tags": []
|
9 |
},
|
10 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"source": [
|
12 |
"%cd .."
|
13 |
]
|
14 |
},
|
15 |
{
|
16 |
"cell_type": "code",
|
17 |
+
"execution_count": null,
|
18 |
"id": "eb53a9fc-90eb-4658-b24c-f6f33c731235",
|
19 |
"metadata": {
|
20 |
"tags": []
|
21 |
},
|
22 |
"outputs": [],
|
23 |
"source": [
|
|
|
24 |
"from pathlib import Path\n",
|
25 |
+
"\n",
|
26 |
+
"import matplotlib.pyplot as plt\n",
|
27 |
+
"import pandas as pd\n",
|
28 |
+
"import PIL.Image\n",
|
29 |
"import torch\n",
|
30 |
"from tqdm import tqdm\n",
|
31 |
"from torch.utils.data import DataLoader\n",
|
32 |
+
"\n",
|
33 |
"from realfake.data import DictDataset, get_augs\n",
|
34 |
"from realfake.models import RealFakeClassifier, RealFakeParams\n",
|
35 |
"from realfake.utils import find_latest_checkpoint, get_user_name, read_jsonl"
|
|
|
37 |
},
|
38 |
{
|
39 |
"cell_type": "code",
|
40 |
+
"execution_count": null,
|
41 |
"id": "41a0c3c5-d01c-46ef-b50b-25d9c297a432",
|
42 |
"metadata": {
|
43 |
"tags": []
|
|
|
58 |
},
|
59 |
{
|
60 |
"cell_type": "code",
|
61 |
+
"execution_count": null,
|
62 |
"id": "14ca2c34-99f0-4088-99d8-9b0cd008097a",
|
63 |
"metadata": {
|
64 |
"tags": []
|
|
|
70 |
},
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
+
"execution_count": null,
|
74 |
+
"id": "d3445075-07ba-424c-9849-26c9de6ce1a8",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
"metadata": {
|
76 |
"tags": []
|
77 |
},
|
78 |
"outputs": [],
|
79 |
"source": [
|
80 |
+
"real = [{\"path\": str(p), \"label\": \"real\"} for p in Path(\"imagenet_val\").iterdir()]\n",
|
81 |
+
"fake = [{\"path\": str(p), \"label\": \"fake\"} for p in Path(\"fakes\").glob(\"**/*.png\")]\n",
|
82 |
+
"data = real + fake\n",
|
83 |
+
"len(data)"
|
84 |
]
|
85 |
},
|
86 |
{
|
87 |
"cell_type": "code",
|
88 |
+
"execution_count": null,
|
89 |
+
"id": "7b76d082-9e35-455d-8c86-0300ffa224d0",
|
90 |
"metadata": {
|
91 |
"tags": []
|
92 |
},
|
93 |
"outputs": [],
|
94 |
"source": [
|
95 |
+
"batch_size = 128\n",
|
96 |
+
"scores = []\n",
|
97 |
+
"\n",
|
98 |
+
"with torch.inference_mode():\n",
|
99 |
+
" ds = DictDataset(data, get_augs(train=False))\n",
|
100 |
+
" dl = DataLoader(ds, batch_size=batch_size, num_workers=8, shuffle=False)\n",
|
101 |
+
"\n",
|
102 |
+
" for batch in tqdm(dl):\n",
|
103 |
+
" _, logits, y_true_onehot = model(batch)\n",
|
104 |
+
" probs = logits.softmax(dim=1)\n",
|
105 |
+
" y_true = y_true_onehot.argmax(dim=1)\n",
|
106 |
+
" y_pred = probs.argmax(dim=1)\n",
|
107 |
+
" matched = y_true == y_pred\n",
|
108 |
+
" \n",
|
109 |
+
" scores += [\n",
|
110 |
+
" {\"fake_prob\": fake_prob.item(), \"match\": match.item()}\n",
|
111 |
+
" for fake_prob, match in zip(probs[:, 1], matched)\n",
|
112 |
+
" ]\n",
|
113 |
+
" \n",
|
114 |
+
"scores = pd.DataFrame(scores)\n",
|
115 |
+
"scores[\"label\"] = [r[\"label\"] for r in data]\n",
|
116 |
+
"scores[\"path\"] = [r[\"path\"] for r in data]"
|
117 |
]
|
118 |
},
|
119 |
{
|
120 |
"cell_type": "code",
|
121 |
+
"execution_count": null,
|
122 |
+
"id": "9c358c74-845d-42f1-9fcb-673e2a90ef69",
|
123 |
"metadata": {
|
124 |
"tags": []
|
125 |
},
|
126 |
"outputs": [],
|
127 |
"source": [
|
128 |
+
"def view_results(df: pd.DataFrame, \n",
|
129 |
+
" query: str, \n",
|
130 |
+
" img_size: int = 256, \n",
|
131 |
+
" plot_size: int = 4,\n",
|
132 |
+
" n_rows: int = 5,\n",
|
133 |
+
" n_cols: int = 5):\n",
|
134 |
+
" \n",
|
135 |
+
" f, axes = plt.subplots(n_rows, n_cols, \n",
|
136 |
+
" figsize=(n_cols*plot_size, n_rows*plot_size), \n",
|
137 |
+
" gridspec_kw={\"hspace\": 0.1, \"wspace\": 0})\n",
|
138 |
+
" \n",
|
139 |
+
" f.subplots_adjust(hspace=0, wspace=0)\n",
|
140 |
+
" \n",
|
141 |
+
" sz = img_size\n",
|
142 |
+
" \n",
|
143 |
+
" items = (df.sort_values(by=\"fake_prob\")\n",
|
144 |
+
" .reset_index(drop=True)\n",
|
145 |
+
" .query(query)\n",
|
146 |
+
" .apply(lambda rec: (\n",
|
147 |
+
" PIL.Image.open(rec.path).resize((sz,sz)), \n",
|
148 |
+
" rec.fake_prob), axis=1)\n",
|
149 |
+
" .path.tolist())\n",
|
150 |
+
"\n",
|
151 |
+
" for ax, (im, score) in zip(axes.flat, items):\n",
|
152 |
+
" ax.imshow(im)\n",
|
153 |
+
" ax.set_title(f\"P(fake)={score:2.2%}\")\n",
|
154 |
+
" ax.set_axis_off()\n",
|
155 |
+
" ax.set_aspect(\"equal\")"
|
156 |
]
|
157 |
},
|
158 |
{
|
159 |
"cell_type": "code",
|
160 |
+
"execution_count": null,
|
161 |
+
"id": "aeb284f7-46a3-408c-afa4-158ba9640571",
|
162 |
"metadata": {
|
163 |
"tags": []
|
164 |
},
|
165 |
"outputs": [],
|
166 |
"source": [
|
167 |
+
"view_results(scores, \"label == 'fake' and fake_prob >= 0.8\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
]
|
169 |
},
|
170 |
{
|
171 |
"cell_type": "code",
|
172 |
+
"execution_count": null,
|
173 |
+
"id": "f518ea74-3c97-4bb9-a461-73c108dac75f",
|
174 |
"metadata": {
|
175 |
"tags": []
|
176 |
},
|
177 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
"source": [
|
179 |
+
"view_results(scores, \"label == 'fake' and fake_prob < 0.5\")"
|
180 |
]
|
181 |
}
|
182 |
],
|