varma123 commited on
Commit
a4c2e30
1 Parent(s): ececc2a

Upload 2 files

Browse files
Files changed (2) hide show
  1. Deepfake_detection.ipynb +1089 -0
  2. requirements.txt +7 -0
Deepfake_detection.ipynb ADDED
@@ -0,0 +1,1089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a2220df6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Import Libraries"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 4,
14
+ "id": "7249bea4",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import gradio as gr\n",
19
+ "import torch\n",
20
+ "import torch.nn.functional as F\n",
21
+ "from facenet_pytorch import MTCNN, InceptionResnetV1\n",
22
+ "import numpy as np\n",
23
+ "from PIL import Image\n",
24
+ "import cv2\n",
25
+ "from pytorch_grad_cam import GradCAM\n",
26
+ "from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget\n",
27
+ "from pytorch_grad_cam.utils.image import show_cam_on_image\n",
28
+ "import warnings\n",
29
+ "warnings.filterwarnings(\"ignore\")"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 2,
35
+ "id": "62f0492b-aad6-4464-ab96-1365b7f3a44e",
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "Requirement already satisfied: gradio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (3.39.0)\n",
43
+ "Collecting gradio\n",
44
+ " Downloading gradio-4.19.1-py3-none-any.whl.metadata (15 kB)\n",
45
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (23.2.1)\n",
46
+ "Requirement already satisfied: altair<6.0,>=4.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (5.2.0)\n",
47
+ "Requirement already satisfied: fastapi in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.109.2)\n",
48
+ "Requirement already satisfied: ffmpy in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.3.2)\n",
49
+ "Requirement already satisfied: gradio-client==0.10.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.10.0)\n",
50
+ "Requirement already satisfied: httpx in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.26.0)\n",
51
+ "Requirement already satisfied: huggingface-hub>=0.19.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.20.3)\n",
52
+ "Requirement already satisfied: importlib-resources<7.0,>=1.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (6.1.1)\n",
53
+ "Requirement already satisfied: jinja2<4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.1.3)\n",
54
+ "Requirement already satisfied: markupsafe~=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.1.5)\n",
55
+ "Requirement already satisfied: matplotlib~=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.8.3)\n",
56
+ "Requirement already satisfied: numpy~=1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (1.26.4)\n",
57
+ "Requirement already satisfied: orjson~=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (3.9.14)\n",
58
+ "Requirement already satisfied: packaging in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (23.2)\n",
59
+ "Requirement already satisfied: pandas<3.0,>=1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.2.0)\n",
60
+ "Requirement already satisfied: pillow<11.0,>=8.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (9.4.0)\n",
61
+ "Requirement already satisfied: pydantic>=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.6.1)\n",
62
+ "Requirement already satisfied: pydub in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.25.1)\n",
63
+ "Requirement already satisfied: python-multipart>=0.0.9 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.0.9)\n",
64
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (6.0.1)\n",
65
+ "Collecting ruff>=0.1.7 (from gradio)\n",
66
+ " Downloading ruff-0.2.1-py3-none-win_amd64.whl.metadata (23 kB)\n",
67
+ "Requirement already satisfied: semantic-version~=2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (2.10.0)\n",
68
+ "Collecting tomlkit==0.12.0 (from gradio)\n",
69
+ " Downloading tomlkit-0.12.0-py3-none-any.whl.metadata (2.7 kB)\n",
70
+ "Collecting typer<1.0,>=0.9 (from typer[all]<1.0,>=0.9->gradio)\n",
71
+ " Downloading typer-0.9.0-py3-none-any.whl.metadata (14 kB)\n",
72
+ "Requirement already satisfied: typing-extensions~=4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (4.9.0)\n",
73
+ "Requirement already satisfied: uvicorn>=0.14.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio) (0.27.1)\n",
74
+ "Requirement already satisfied: fsspec in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio-client==0.10.0->gradio) (2024.2.0)\n",
75
+ "Requirement already satisfied: websockets<12.0,>=10.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from gradio-client==0.10.0->gradio) (11.0.3)\n",
76
+ "Requirement already satisfied: jsonschema>=3.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (4.21.1)\n",
77
+ "Requirement already satisfied: toolz in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (0.12.1)\n",
78
+ "Requirement already satisfied: filelock in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (3.13.1)\n",
79
+ "Requirement already satisfied: requests in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (2.31.0)\n",
80
+ "Requirement already satisfied: tqdm>=4.42.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (4.66.2)\n",
81
+ "Requirement already satisfied: zipp>=3.1.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from importlib-resources<7.0,>=1.3->gradio) (3.17.0)\n",
82
+ "Requirement already satisfied: contourpy>=1.0.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.2.0)\n",
83
+ "Requirement already satisfied: cycler>=0.10 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (0.12.1)\n",
84
+ "Requirement already satisfied: fonttools>=4.22.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (4.49.0)\n",
85
+ "Requirement already satisfied: kiwisolver>=1.3.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.4.5)\n",
86
+ "Requirement already satisfied: pyparsing>=2.3.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (3.1.1)\n",
87
+ "Requirement already satisfied: python-dateutil>=2.7 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from matplotlib~=3.0->gradio) (2.8.2)\n",
88
+ "Requirement already satisfied: pytz>=2020.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n",
89
+ "Requirement already satisfied: tzdata>=2022.7 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n",
90
+ "Requirement already satisfied: annotated-types>=0.4.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pydantic>=2.0->gradio) (0.6.0)\n",
91
+ "Requirement already satisfied: pydantic-core==2.16.2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from pydantic>=2.0->gradio) (2.16.2)\n",
92
+ "Requirement already satisfied: click<9.0.0,>=7.1.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from typer<1.0,>=0.9->typer[all]<1.0,>=0.9->gradio) (8.1.7)\n",
93
+ "Requirement already satisfied: colorama<0.5.0,>=0.4.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from typer[all]<1.0,>=0.9->gradio) (0.4.6)\n",
94
+ "Collecting shellingham<2.0.0,>=1.3.0 (from typer[all]<1.0,>=0.9->gradio)\n",
95
+ " Downloading shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)\n",
96
+ "Collecting rich<14.0.0,>=10.11.0 (from typer[all]<1.0,>=0.9->gradio)\n",
97
+ " Downloading rich-13.7.0-py3-none-any.whl.metadata (18 kB)\n",
98
+ "Requirement already satisfied: h11>=0.8 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from uvicorn>=0.14.0->gradio) (0.14.0)\n",
99
+ "Requirement already satisfied: starlette<0.37.0,>=0.36.3 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from fastapi->gradio) (0.36.3)\n",
100
+ "Requirement already satisfied: anyio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (4.2.0)\n",
101
+ "Requirement already satisfied: certifi in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (2024.2.2)\n",
102
+ "Requirement already satisfied: httpcore==1.* in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (1.0.3)\n",
103
+ "Requirement already satisfied: idna in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (3.6)\n",
104
+ "Requirement already satisfied: sniffio in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from httpx->gradio) (1.3.0)\n",
105
+ "Requirement already satisfied: attrs>=22.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.2.0)\n",
106
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.12.1)\n",
107
+ "Requirement already satisfied: referencing>=0.28.4 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.33.0)\n",
108
+ "Requirement already satisfied: rpds-py>=0.7.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.18.0)\n",
109
+ "Requirement already satisfied: six>=1.5 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n",
110
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (2.2.0)\n",
111
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (2.17.2)\n",
112
+ "Requirement already satisfied: exceptiongroup>=1.0.2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from anyio->httpx->gradio) (1.2.0)\n",
113
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.3.2)\n",
114
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (2.2.0)\n",
115
+ "Requirement already satisfied: mdurl~=0.1 in c:\\kandikits\\deepfake-detection\\deepfake-detection-env\\lib\\site-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (0.1.2)\n",
116
+ "Downloading gradio-4.19.1-py3-none-any.whl (16.9 MB)\n",
117
+ " ---------------------------------------- 0.0/16.9 MB ? eta -:--:--\n",
118
+ " ---------------------------------------- 0.0/16.9 MB 1.9 MB/s eta 0:00:09\n",
119
+ " ---------------------------------------- 0.2/16.9 MB 2.4 MB/s eta 0:00:08\n",
120
+ " - -------------------------------------- 0.5/16.9 MB 3.9 MB/s eta 0:00:05\n",
121
+ " -- ------------------------------------- 1.1/16.9 MB 6.1 MB/s eta 0:00:03\n",
122
+ " ----- ---------------------------------- 2.2/16.9 MB 10.2 MB/s eta 0:00:02\n",
123
+ " -------- ------------------------------- 3.7/16.9 MB 13.9 MB/s eta 0:00:01\n",
124
+ " ----------- ---------------------------- 5.1/16.9 MB 15.4 MB/s eta 0:00:01\n",
125
+ " --------------- ------------------------ 6.6/16.9 MB 18.4 MB/s eta 0:00:01\n",
126
+ " ------------------- -------------------- 8.4/16.9 MB 20.6 MB/s eta 0:00:01\n",
127
+ " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n",
128
+ " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n",
129
+ " --------------------- ------------------ 9.2/16.9 MB 21.1 MB/s eta 0:00:01\n",
130
+ " --------------------- ------------------ 9.3/16.9 MB 16.5 MB/s eta 0:00:01\n",
131
+ " ------------------------ --------------- 10.3/16.9 MB 17.2 MB/s eta 0:00:01\n",
132
+ " ------------------------ --------------- 10.5/16.9 MB 17.3 MB/s eta 0:00:01\n",
133
+ " -------------------------- ------------- 11.1/16.9 MB 18.7 MB/s eta 0:00:01\n",
134
+ " ---------------------------- ----------- 12.1/16.9 MB 17.7 MB/s eta 0:00:01\n",
135
+ " -------------------------------- ------- 13.6/16.9 MB 18.2 MB/s eta 0:00:01\n",
136
+ " ----------------------------------- ---- 14.9/16.9 MB 18.2 MB/s eta 0:00:01\n",
137
+ " --------------------------------------- 16.6/16.9 MB 18.2 MB/s eta 0:00:01\n",
138
+ " ---------------------------------------- 16.9/16.9 MB 16.8 MB/s eta 0:00:00\n",
139
+ "Downloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\n",
140
+ "Downloading ruff-0.2.1-py3-none-win_amd64.whl (7.4 MB)\n",
141
+ " ---------------------------------------- 0.0/7.4 MB ? eta -:--:--\n",
142
+ " -------- ------------------------------- 1.6/7.4 MB 51.9 MB/s eta 0:00:01\n",
143
+ " ---------------- ----------------------- 3.1/7.4 MB 40.2 MB/s eta 0:00:01\n",
144
+ " ------------------------- -------------- 4.8/7.4 MB 37.9 MB/s eta 0:00:01\n",
145
+ " --------------------------------- ------ 6.3/7.4 MB 36.5 MB/s eta 0:00:01\n",
146
+ " --------------------------------------- 7.4/7.4 MB 33.7 MB/s eta 0:00:01\n",
147
+ " ---------------------------------------- 7.4/7.4 MB 31.6 MB/s eta 0:00:00\n",
148
+ "Downloading typer-0.9.0-py3-none-any.whl (45 kB)\n",
149
+ " ---------------------------------------- 0.0/45.9 kB ? eta -:--:--\n",
150
+ " ---------------------------------------- 45.9/45.9 kB ? eta 0:00:00\n",
151
+ "Downloading rich-13.7.0-py3-none-any.whl (240 kB)\n",
152
+ " ---------------------------------------- 0.0/240.6 kB ? eta -:--:--\n",
153
+ " --------------------------------------- 240.6/240.6 kB 14.4 MB/s eta 0:00:00\n",
154
+ "Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)\n",
155
+ "Installing collected packages: tomlkit, shellingham, ruff, typer, rich, gradio\n",
156
+ " Attempting uninstall: gradio\n",
157
+ " Found existing installation: gradio 3.39.0\n",
158
+ " Uninstalling gradio-3.39.0:\n",
159
+ " Successfully uninstalled gradio-3.39.0\n",
160
+ "Successfully installed gradio-4.19.1 rich-13.7.0 ruff-0.2.1 shellingham-1.5.4 tomlkit-0.12.0 typer-0.9.0\n"
161
+ ]
162
+ }
163
+ ],
164
+ "source": [
165
+ "!pip install -U gradio"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "id": "d25e1c5d",
171
+ "metadata": {},
172
+ "source": [
173
+ "# Download and Load Model"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": 5,
179
+ "id": "237fbf44",
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
184
+ "\n",
185
+ "mtcnn = MTCNN(\n",
186
+ " select_largest=False,\n",
187
+ " post_process=False,\n",
188
+ " device=DEVICE\n",
189
+ ").to(DEVICE).eval()"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 6,
195
+ "id": "f3ef2b4f",
196
+ "metadata": {},
197
+ "outputs": [
198
+ {
199
+ "data": {
200
+ "application/vnd.jupyter.widget-view+json": {
201
+ "model_id": "43131e0cdbdf44beb6f775f854ebbf07",
202
+ "version_major": 2,
203
+ "version_minor": 0
204
+ },
205
+ "text/plain": [
206
+ " 0%| | 0.00/107M [00:00<?, ?B/s]"
207
+ ]
208
+ },
209
+ "metadata": {},
210
+ "output_type": "display_data"
211
+ },
212
+ {
213
+ "data": {
214
+ "text/plain": [
215
+ "InceptionResnetV1(\n",
216
+ " (conv2d_1a): BasicConv2d(\n",
217
+ " (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
218
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
219
+ " (relu): ReLU()\n",
220
+ " )\n",
221
+ " (conv2d_2a): BasicConv2d(\n",
222
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
223
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
224
+ " (relu): ReLU()\n",
225
+ " )\n",
226
+ " (conv2d_2b): BasicConv2d(\n",
227
+ " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
228
+ " (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
229
+ " (relu): ReLU()\n",
230
+ " )\n",
231
+ " (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
232
+ " (conv2d_3b): BasicConv2d(\n",
233
+ " (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
234
+ " (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
235
+ " (relu): ReLU()\n",
236
+ " )\n",
237
+ " (conv2d_4a): BasicConv2d(\n",
238
+ " (conv): Conv2d(80, 192, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
239
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
240
+ " (relu): ReLU()\n",
241
+ " )\n",
242
+ " (conv2d_4b): BasicConv2d(\n",
243
+ " (conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
244
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
245
+ " (relu): ReLU()\n",
246
+ " )\n",
247
+ " (repeat_1): Sequential(\n",
248
+ " (0): Block35(\n",
249
+ " (branch0): BasicConv2d(\n",
250
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
251
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
252
+ " (relu): ReLU()\n",
253
+ " )\n",
254
+ " (branch1): Sequential(\n",
255
+ " (0): BasicConv2d(\n",
256
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
257
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
258
+ " (relu): ReLU()\n",
259
+ " )\n",
260
+ " (1): BasicConv2d(\n",
261
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
262
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
263
+ " (relu): ReLU()\n",
264
+ " )\n",
265
+ " )\n",
266
+ " (branch2): Sequential(\n",
267
+ " (0): BasicConv2d(\n",
268
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
269
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
270
+ " (relu): ReLU()\n",
271
+ " )\n",
272
+ " (1): BasicConv2d(\n",
273
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
274
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
275
+ " (relu): ReLU()\n",
276
+ " )\n",
277
+ " (2): BasicConv2d(\n",
278
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
279
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " (relu): ReLU()\n",
281
+ " )\n",
282
+ " )\n",
283
+ " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n",
284
+ " (relu): ReLU()\n",
285
+ " )\n",
286
+ " (1): Block35(\n",
287
+ " (branch0): BasicConv2d(\n",
288
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
289
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
290
+ " (relu): ReLU()\n",
291
+ " )\n",
292
+ " (branch1): Sequential(\n",
293
+ " (0): BasicConv2d(\n",
294
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
295
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
296
+ " (relu): ReLU()\n",
297
+ " )\n",
298
+ " (1): BasicConv2d(\n",
299
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
300
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
301
+ " (relu): ReLU()\n",
302
+ " )\n",
303
+ " )\n",
304
+ " (branch2): Sequential(\n",
305
+ " (0): BasicConv2d(\n",
306
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
307
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
308
+ " (relu): ReLU()\n",
309
+ " )\n",
310
+ " (1): BasicConv2d(\n",
311
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
312
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
313
+ " (relu): ReLU()\n",
314
+ " )\n",
315
+ " (2): BasicConv2d(\n",
316
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
317
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
318
+ " (relu): ReLU()\n",
319
+ " )\n",
320
+ " )\n",
321
+ " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n",
322
+ " (relu): ReLU()\n",
323
+ " )\n",
324
+ " (2): Block35(\n",
325
+ " (branch0): BasicConv2d(\n",
326
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
327
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
328
+ " (relu): ReLU()\n",
329
+ " )\n",
330
+ " (branch1): Sequential(\n",
331
+ " (0): BasicConv2d(\n",
332
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
333
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
334
+ " (relu): ReLU()\n",
335
+ " )\n",
336
+ " (1): BasicConv2d(\n",
337
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
338
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
339
+ " (relu): ReLU()\n",
340
+ " )\n",
341
+ " )\n",
342
+ " (branch2): Sequential(\n",
343
+ " (0): BasicConv2d(\n",
344
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
345
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
346
+ " (relu): ReLU()\n",
347
+ " )\n",
348
+ " (1): BasicConv2d(\n",
349
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
350
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
351
+ " (relu): ReLU()\n",
352
+ " )\n",
353
+ " (2): BasicConv2d(\n",
354
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
355
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
356
+ " (relu): ReLU()\n",
357
+ " )\n",
358
+ " )\n",
359
+ " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n",
360
+ " (relu): ReLU()\n",
361
+ " )\n",
362
+ " (3): Block35(\n",
363
+ " (branch0): BasicConv2d(\n",
364
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
365
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
366
+ " (relu): ReLU()\n",
367
+ " )\n",
368
+ " (branch1): Sequential(\n",
369
+ " (0): BasicConv2d(\n",
370
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
371
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
372
+ " (relu): ReLU()\n",
373
+ " )\n",
374
+ " (1): BasicConv2d(\n",
375
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
376
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
377
+ " (relu): ReLU()\n",
378
+ " )\n",
379
+ " )\n",
380
+ " (branch2): Sequential(\n",
381
+ " (0): BasicConv2d(\n",
382
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
383
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
384
+ " (relu): ReLU()\n",
385
+ " )\n",
386
+ " (1): BasicConv2d(\n",
387
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
388
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
389
+ " (relu): ReLU()\n",
390
+ " )\n",
391
+ " (2): BasicConv2d(\n",
392
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
393
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
394
+ " (relu): ReLU()\n",
395
+ " )\n",
396
+ " )\n",
397
+ " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n",
398
+ " (relu): ReLU()\n",
399
+ " )\n",
400
+ " (4): Block35(\n",
401
+ " (branch0): BasicConv2d(\n",
402
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
403
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
404
+ " (relu): ReLU()\n",
405
+ " )\n",
406
+ " (branch1): Sequential(\n",
407
+ " (0): BasicConv2d(\n",
408
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
409
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
410
+ " (relu): ReLU()\n",
411
+ " )\n",
412
+ " (1): BasicConv2d(\n",
413
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
414
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
415
+ " (relu): ReLU()\n",
416
+ " )\n",
417
+ " )\n",
418
+ " (branch2): Sequential(\n",
419
+ " (0): BasicConv2d(\n",
420
+ " (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
421
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
422
+ " (relu): ReLU()\n",
423
+ " )\n",
424
+ " (1): BasicConv2d(\n",
425
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
426
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
427
+ " (relu): ReLU()\n",
428
+ " )\n",
429
+ " (2): BasicConv2d(\n",
430
+ " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
431
+ " (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
432
+ " (relu): ReLU()\n",
433
+ " )\n",
434
+ " )\n",
435
+ " (conv2d): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))\n",
436
+ " (relu): ReLU()\n",
437
+ " )\n",
438
+ " )\n",
439
+ " (mixed_6a): Mixed_6a(\n",
440
+ " (branch0): BasicConv2d(\n",
441
+ " (conv): Conv2d(256, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
442
+ " (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
443
+ " (relu): ReLU()\n",
444
+ " )\n",
445
+ " (branch1): Sequential(\n",
446
+ " (0): BasicConv2d(\n",
447
+ " (conv): Conv2d(256, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
448
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
449
+ " (relu): ReLU()\n",
450
+ " )\n",
451
+ " (1): BasicConv2d(\n",
452
+ " (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
453
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
454
+ " (relu): ReLU()\n",
455
+ " )\n",
456
+ " (2): BasicConv2d(\n",
457
+ " (conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
458
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
459
+ " (relu): ReLU()\n",
460
+ " )\n",
461
+ " )\n",
462
+ " (branch2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
463
+ " )\n",
464
+ " (repeat_2): Sequential(\n",
465
+ " (0): Block17(\n",
466
+ " (branch0): BasicConv2d(\n",
467
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
468
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
469
+ " (relu): ReLU()\n",
470
+ " )\n",
471
+ " (branch1): Sequential(\n",
472
+ " (0): BasicConv2d(\n",
473
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
474
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
475
+ " (relu): ReLU()\n",
476
+ " )\n",
477
+ " (1): BasicConv2d(\n",
478
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
479
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
480
+ " (relu): ReLU()\n",
481
+ " )\n",
482
+ " (2): BasicConv2d(\n",
483
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
484
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
485
+ " (relu): ReLU()\n",
486
+ " )\n",
487
+ " )\n",
488
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
489
+ " (relu): ReLU()\n",
490
+ " )\n",
491
+ " (1): Block17(\n",
492
+ " (branch0): BasicConv2d(\n",
493
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
494
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
495
+ " (relu): ReLU()\n",
496
+ " )\n",
497
+ " (branch1): Sequential(\n",
498
+ " (0): BasicConv2d(\n",
499
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
500
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
501
+ " (relu): ReLU()\n",
502
+ " )\n",
503
+ " (1): BasicConv2d(\n",
504
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
505
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
506
+ " (relu): ReLU()\n",
507
+ " )\n",
508
+ " (2): BasicConv2d(\n",
509
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
510
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
511
+ " (relu): ReLU()\n",
512
+ " )\n",
513
+ " )\n",
514
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
515
+ " (relu): ReLU()\n",
516
+ " )\n",
517
+ " (2): Block17(\n",
518
+ " (branch0): BasicConv2d(\n",
519
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
520
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
521
+ " (relu): ReLU()\n",
522
+ " )\n",
523
+ " (branch1): Sequential(\n",
524
+ " (0): BasicConv2d(\n",
525
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
526
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
527
+ " (relu): ReLU()\n",
528
+ " )\n",
529
+ " (1): BasicConv2d(\n",
530
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
531
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
532
+ " (relu): ReLU()\n",
533
+ " )\n",
534
+ " (2): BasicConv2d(\n",
535
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
536
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
537
+ " (relu): ReLU()\n",
538
+ " )\n",
539
+ " )\n",
540
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
541
+ " (relu): ReLU()\n",
542
+ " )\n",
543
+ " (3): Block17(\n",
544
+ " (branch0): BasicConv2d(\n",
545
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
546
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
547
+ " (relu): ReLU()\n",
548
+ " )\n",
549
+ " (branch1): Sequential(\n",
550
+ " (0): BasicConv2d(\n",
551
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
552
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
553
+ " (relu): ReLU()\n",
554
+ " )\n",
555
+ " (1): BasicConv2d(\n",
556
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
557
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
558
+ " (relu): ReLU()\n",
559
+ " )\n",
560
+ " (2): BasicConv2d(\n",
561
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
562
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
563
+ " (relu): ReLU()\n",
564
+ " )\n",
565
+ " )\n",
566
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
567
+ " (relu): ReLU()\n",
568
+ " )\n",
569
+ " (4): Block17(\n",
570
+ " (branch0): BasicConv2d(\n",
571
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
572
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
573
+ " (relu): ReLU()\n",
574
+ " )\n",
575
+ " (branch1): Sequential(\n",
576
+ " (0): BasicConv2d(\n",
577
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
578
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
579
+ " (relu): ReLU()\n",
580
+ " )\n",
581
+ " (1): BasicConv2d(\n",
582
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
583
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
584
+ " (relu): ReLU()\n",
585
+ " )\n",
586
+ " (2): BasicConv2d(\n",
587
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
588
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
589
+ " (relu): ReLU()\n",
590
+ " )\n",
591
+ " )\n",
592
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
593
+ " (relu): ReLU()\n",
594
+ " )\n",
595
+ " (5): Block17(\n",
596
+ " (branch0): BasicConv2d(\n",
597
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
598
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
599
+ " (relu): ReLU()\n",
600
+ " )\n",
601
+ " (branch1): Sequential(\n",
602
+ " (0): BasicConv2d(\n",
603
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
604
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
605
+ " (relu): ReLU()\n",
606
+ " )\n",
607
+ " (1): BasicConv2d(\n",
608
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
609
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
610
+ " (relu): ReLU()\n",
611
+ " )\n",
612
+ " (2): BasicConv2d(\n",
613
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
614
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
615
+ " (relu): ReLU()\n",
616
+ " )\n",
617
+ " )\n",
618
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
619
+ " (relu): ReLU()\n",
620
+ " )\n",
621
+ " (6): Block17(\n",
622
+ " (branch0): BasicConv2d(\n",
623
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
624
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
625
+ " (relu): ReLU()\n",
626
+ " )\n",
627
+ " (branch1): Sequential(\n",
628
+ " (0): BasicConv2d(\n",
629
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
630
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
631
+ " (relu): ReLU()\n",
632
+ " )\n",
633
+ " (1): BasicConv2d(\n",
634
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
635
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
636
+ " (relu): ReLU()\n",
637
+ " )\n",
638
+ " (2): BasicConv2d(\n",
639
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
640
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
641
+ " (relu): ReLU()\n",
642
+ " )\n",
643
+ " )\n",
644
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
645
+ " (relu): ReLU()\n",
646
+ " )\n",
647
+ " (7): Block17(\n",
648
+ " (branch0): BasicConv2d(\n",
649
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
650
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
651
+ " (relu): ReLU()\n",
652
+ " )\n",
653
+ " (branch1): Sequential(\n",
654
+ " (0): BasicConv2d(\n",
655
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
656
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
657
+ " (relu): ReLU()\n",
658
+ " )\n",
659
+ " (1): BasicConv2d(\n",
660
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
661
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
662
+ " (relu): ReLU()\n",
663
+ " )\n",
664
+ " (2): BasicConv2d(\n",
665
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
666
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
667
+ " (relu): ReLU()\n",
668
+ " )\n",
669
+ " )\n",
670
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
671
+ " (relu): ReLU()\n",
672
+ " )\n",
673
+ " (8): Block17(\n",
674
+ " (branch0): BasicConv2d(\n",
675
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
676
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
677
+ " (relu): ReLU()\n",
678
+ " )\n",
679
+ " (branch1): Sequential(\n",
680
+ " (0): BasicConv2d(\n",
681
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
682
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
683
+ " (relu): ReLU()\n",
684
+ " )\n",
685
+ " (1): BasicConv2d(\n",
686
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
687
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
688
+ " (relu): ReLU()\n",
689
+ " )\n",
690
+ " (2): BasicConv2d(\n",
691
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
692
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
693
+ " (relu): ReLU()\n",
694
+ " )\n",
695
+ " )\n",
696
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
697
+ " (relu): ReLU()\n",
698
+ " )\n",
699
+ " (9): Block17(\n",
700
+ " (branch0): BasicConv2d(\n",
701
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
702
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
703
+ " (relu): ReLU()\n",
704
+ " )\n",
705
+ " (branch1): Sequential(\n",
706
+ " (0): BasicConv2d(\n",
707
+ " (conv): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
708
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
709
+ " (relu): ReLU()\n",
710
+ " )\n",
711
+ " (1): BasicConv2d(\n",
712
+ " (conv): Conv2d(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 3), bias=False)\n",
713
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
714
+ " (relu): ReLU()\n",
715
+ " )\n",
716
+ " (2): BasicConv2d(\n",
717
+ " (conv): Conv2d(128, 128, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), bias=False)\n",
718
+ " (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
719
+ " (relu): ReLU()\n",
720
+ " )\n",
721
+ " )\n",
722
+ " (conv2d): Conv2d(256, 896, kernel_size=(1, 1), stride=(1, 1))\n",
723
+ " (relu): ReLU()\n",
724
+ " )\n",
725
+ " )\n",
726
+ " (mixed_7a): Mixed_7a(\n",
727
+ " (branch0): Sequential(\n",
728
+ " (0): BasicConv2d(\n",
729
+ " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
730
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
731
+ " (relu): ReLU()\n",
732
+ " )\n",
733
+ " (1): BasicConv2d(\n",
734
+ " (conv): Conv2d(256, 384, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
735
+ " (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
736
+ " (relu): ReLU()\n",
737
+ " )\n",
738
+ " )\n",
739
+ " (branch1): Sequential(\n",
740
+ " (0): BasicConv2d(\n",
741
+ " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
742
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
743
+ " (relu): ReLU()\n",
744
+ " )\n",
745
+ " (1): BasicConv2d(\n",
746
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
747
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
748
+ " (relu): ReLU()\n",
749
+ " )\n",
750
+ " )\n",
751
+ " (branch2): Sequential(\n",
752
+ " (0): BasicConv2d(\n",
753
+ " (conv): Conv2d(896, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
754
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
755
+ " (relu): ReLU()\n",
756
+ " )\n",
757
+ " (1): BasicConv2d(\n",
758
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
759
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
760
+ " (relu): ReLU()\n",
761
+ " )\n",
762
+ " (2): BasicConv2d(\n",
763
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
764
+ " (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
765
+ " (relu): ReLU()\n",
766
+ " )\n",
767
+ " )\n",
768
+ " (branch3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
769
+ " )\n",
770
+ " (repeat_3): Sequential(\n",
771
+ " (0): Block8(\n",
772
+ " (branch0): BasicConv2d(\n",
773
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
774
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
775
+ " (relu): ReLU()\n",
776
+ " )\n",
777
+ " (branch1): Sequential(\n",
778
+ " (0): BasicConv2d(\n",
779
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
780
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
781
+ " (relu): ReLU()\n",
782
+ " )\n",
783
+ " (1): BasicConv2d(\n",
784
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
785
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
786
+ " (relu): ReLU()\n",
787
+ " )\n",
788
+ " (2): BasicConv2d(\n",
789
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
790
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
791
+ " (relu): ReLU()\n",
792
+ " )\n",
793
+ " )\n",
794
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
795
+ " (relu): ReLU()\n",
796
+ " )\n",
797
+ " (1): Block8(\n",
798
+ " (branch0): BasicConv2d(\n",
799
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
800
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
801
+ " (relu): ReLU()\n",
802
+ " )\n",
803
+ " (branch1): Sequential(\n",
804
+ " (0): BasicConv2d(\n",
805
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
806
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
807
+ " (relu): ReLU()\n",
808
+ " )\n",
809
+ " (1): BasicConv2d(\n",
810
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
811
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
812
+ " (relu): ReLU()\n",
813
+ " )\n",
814
+ " (2): BasicConv2d(\n",
815
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
816
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
817
+ " (relu): ReLU()\n",
818
+ " )\n",
819
+ " )\n",
820
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
821
+ " (relu): ReLU()\n",
822
+ " )\n",
823
+ " (2): Block8(\n",
824
+ " (branch0): BasicConv2d(\n",
825
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
826
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
827
+ " (relu): ReLU()\n",
828
+ " )\n",
829
+ " (branch1): Sequential(\n",
830
+ " (0): BasicConv2d(\n",
831
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
832
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
833
+ " (relu): ReLU()\n",
834
+ " )\n",
835
+ " (1): BasicConv2d(\n",
836
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
837
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
838
+ " (relu): ReLU()\n",
839
+ " )\n",
840
+ " (2): BasicConv2d(\n",
841
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
842
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
843
+ " (relu): ReLU()\n",
844
+ " )\n",
845
+ " )\n",
846
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
847
+ " (relu): ReLU()\n",
848
+ " )\n",
849
+ " (3): Block8(\n",
850
+ " (branch0): BasicConv2d(\n",
851
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
852
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
853
+ " (relu): ReLU()\n",
854
+ " )\n",
855
+ " (branch1): Sequential(\n",
856
+ " (0): BasicConv2d(\n",
857
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
858
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
859
+ " (relu): ReLU()\n",
860
+ " )\n",
861
+ " (1): BasicConv2d(\n",
862
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
863
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
864
+ " (relu): ReLU()\n",
865
+ " )\n",
866
+ " (2): BasicConv2d(\n",
867
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
868
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
869
+ " (relu): ReLU()\n",
870
+ " )\n",
871
+ " )\n",
872
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
873
+ " (relu): ReLU()\n",
874
+ " )\n",
875
+ " (4): Block8(\n",
876
+ " (branch0): BasicConv2d(\n",
877
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
878
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
879
+ " (relu): ReLU()\n",
880
+ " )\n",
881
+ " (branch1): Sequential(\n",
882
+ " (0): BasicConv2d(\n",
883
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
884
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
885
+ " (relu): ReLU()\n",
886
+ " )\n",
887
+ " (1): BasicConv2d(\n",
888
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
889
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
890
+ " (relu): ReLU()\n",
891
+ " )\n",
892
+ " (2): BasicConv2d(\n",
893
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
894
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
895
+ " (relu): ReLU()\n",
896
+ " )\n",
897
+ " )\n",
898
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
899
+ " (relu): ReLU()\n",
900
+ " )\n",
901
+ " )\n",
902
+ " (block8): Block8(\n",
903
+ " (branch0): BasicConv2d(\n",
904
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
905
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
906
+ " (relu): ReLU()\n",
907
+ " )\n",
908
+ " (branch1): Sequential(\n",
909
+ " (0): BasicConv2d(\n",
910
+ " (conv): Conv2d(1792, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
911
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
912
+ " (relu): ReLU()\n",
913
+ " )\n",
914
+ " (1): BasicConv2d(\n",
915
+ " (conv): Conv2d(192, 192, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), bias=False)\n",
916
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
917
+ " (relu): ReLU()\n",
918
+ " )\n",
919
+ " (2): BasicConv2d(\n",
920
+ " (conv): Conv2d(192, 192, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), bias=False)\n",
921
+ " (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
922
+ " (relu): ReLU()\n",
923
+ " )\n",
924
+ " )\n",
925
+ " (conv2d): Conv2d(384, 1792, kernel_size=(1, 1), stride=(1, 1))\n",
926
+ " )\n",
927
+ " (avgpool_1a): AdaptiveAvgPool2d(output_size=1)\n",
928
+ " (dropout): Dropout(p=0.6, inplace=False)\n",
929
+ " (last_linear): Linear(in_features=1792, out_features=512, bias=False)\n",
930
+ " (last_bn): BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n",
931
+ " (logits): Linear(in_features=512, out_features=1, bias=True)\n",
932
+ ")"
933
+ ]
934
+ },
935
+ "execution_count": 6,
936
+ "metadata": {},
937
+ "output_type": "execute_result"
938
+ }
939
+ ],
940
+ "source": [
941
+ "model = InceptionResnetV1(\n",
942
+ " pretrained=\"vggface2\",\n",
943
+ " classify=True,\n",
944
+ " num_classes=1,\n",
945
+ " device=DEVICE\n",
946
+ ")\n",
947
+ "\n",
948
+ "checkpoint = torch.load(\"resnetinceptionv1_epoch_32.pth\", map_location=torch.device('cpu'))\n",
949
+ "model.load_state_dict(checkpoint['model_state_dict'])\n",
950
+ "model.to(DEVICE)\n",
951
+ "model.eval()"
952
+ ]
953
+ },
954
+ {
955
+ "cell_type": "markdown",
956
+ "id": "a499194a",
957
+ "metadata": {},
958
+ "source": [
959
+ "# Model Inference "
960
+ ]
961
+ },
962
+ {
963
+ "cell_type": "code",
964
+ "execution_count": 8,
965
+ "id": "376e6cd6",
966
+ "metadata": {},
967
+ "outputs": [],
968
+ "source": [
969
+ "def predict(input_image:Image.Image):\n",
970
+ " \"\"\"Predict the label of the input_image\"\"\"\n",
971
+ " face = mtcnn(input_image)\n",
972
+ " if face is None:\n",
973
+ " raise Exception('No face detected')\n",
974
+ " face = face.unsqueeze(0) # add the batch dimension\n",
975
+ " face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)\n",
976
+ " \n",
977
+ " # convert the face into a numpy array to be able to plot it\n",
978
+ " prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()\n",
979
+ " prev_face = prev_face.astype('uint8')\n",
980
+ "\n",
981
+ " face = face.to(DEVICE)\n",
982
+ " face = face.to(torch.float32)\n",
983
+ " face = face / 255.0\n",
984
+ " face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()\n",
985
+ "\n",
986
+ " target_layers=[model.block8.branch1[-1]]\n",
987
+ " use_cuda = True if torch.cuda.is_available() else False\n",
988
+ " cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)\n",
989
+ " targets = [ClassifierOutputTarget(0)]\n",
990
+ "\n",
991
+ " grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)\n",
992
+ " grayscale_cam = grayscale_cam[0, :]\n",
993
+ " visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)\n",
994
+ " face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)\n",
995
+ "\n",
996
+ " with torch.no_grad():\n",
997
+ " output = torch.sigmoid(model(face).squeeze(0))\n",
998
+ " prediction = \"real\" if output.item() < 0.5 else \"fake\"\n",
999
+ " \n",
1000
+ " real_prediction = 1 - output.item()\n",
1001
+ " fake_prediction = output.item()\n",
1002
+ " \n",
1003
+ " confidences = {\n",
1004
+ " 'real': real_prediction,\n",
1005
+ " 'fake': fake_prediction\n",
1006
+ " }\n",
1007
+ " return confidences, face_with_mask\n"
1008
+ ]
1009
+ },
1010
+ {
1011
+ "cell_type": "markdown",
1012
+ "id": "14f47b5a",
1013
+ "metadata": {},
1014
+ "source": [
1015
+ "# Gradio Interface"
1016
+ ]
1017
+ },
1018
+ {
1019
+ "cell_type": "code",
1020
+ "execution_count": 9,
1021
+ "id": "d62177b5",
1022
+ "metadata": {},
1023
+ "outputs": [
1024
+ {
1025
+ "name": "stdout",
1026
+ "output_type": "stream",
1027
+ "text": [
1028
+ "Running on local URL: http://127.0.0.1:7860\n",
1029
+ "\n",
1030
+ "To create a public link, set `share=True` in `launch()`.\n"
1031
+ ]
1032
+ },
1033
+ {
1034
+ "data": {
1035
+ "text/html": [
1036
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
1037
+ ],
1038
+ "text/plain": [
1039
+ "<IPython.core.display.HTML object>"
1040
+ ]
1041
+ },
1042
+ "metadata": {},
1043
+ "output_type": "display_data"
1044
+ }
1045
+ ],
1046
+ "source": [
1047
+ "interface = gr.Interface(\n",
1048
+ " fn=predict,\n",
1049
+ " inputs=[\n",
1050
+ " gr.inputs.Image(label=\"Input Image\", type=\"pil\")\n",
1051
+ " ],\n",
1052
+ " outputs=[\n",
1053
+ " gr.outputs.Label(label=\"Class\"),\n",
1054
+ " gr.outputs.Image(label=\"Face with Explainability\", type=\"pil\")\n",
1055
+ " ],\n",
1056
+ ").launch()"
1057
+ ]
1058
+ },
1059
+ {
1060
+ "cell_type": "code",
1061
+ "execution_count": null,
1062
+ "id": "0c0b293c",
1063
+ "metadata": {},
1064
+ "outputs": [],
1065
+ "source": []
1066
+ }
1067
+ ],
1068
+ "metadata": {
1069
+ "kernelspec": {
1070
+ "display_name": "Python 3 (ipykernel)",
1071
+ "language": "python",
1072
+ "name": "python3"
1073
+ },
1074
+ "language_info": {
1075
+ "codemirror_mode": {
1076
+ "name": "ipython",
1077
+ "version": 3
1078
+ },
1079
+ "file_extension": ".py",
1080
+ "mimetype": "text/x-python",
1081
+ "name": "python",
1082
+ "nbconvert_exporter": "python",
1083
+ "pygments_lexer": "ipython3",
1084
+ "version": "3.9.8"
1085
+ }
1086
+ },
1087
+ "nbformat": 4,
1088
+ "nbformat_minor": 5
1089
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ jupyter==1.0.0
2
+ gradio==3.23.0
3
+ Pillow==9.4.0
4
+ facenet-pytorch==2.5.2
5
+ torch==1.11.0
6
+ opencv-python==4.7.0.72
7
+ grad-cam==1.4.6