Gabriel Edradan committed
Commit e0f9876 · verified · 1 Parent(s): 57d515b

Delete hmc_grad.ipynb

Files changed (1)
  1. hmc_grad.ipynb +0 -1032
hmc_grad.ipynb DELETED
@@ -1,1032 +0,0 @@
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9z2oy94RRxJa"
- },
- "source": [
- "# HMC-GRAD: HANDWRITTEN MULTIPLE-CHOICE TEST GRADER\n",
- "This notebook implements HMC-Grad, employing OpenCV for image preprocessing and PyTorch for a convolutional neural network trained on the EMNIST dataset for image classification."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "EaklWdMZRwYU"
- },
- "source": [
- "## IMPORT MODULES AND SET GLOBALS"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "d5Zz7NinfDDM",
- "outputId": "f6555720-b483-4cf8-8d95-0ab530d02974"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting gradio\n",
- " Downloading gradio-4.14.0-py3-none-any.whl (16.6 MB)\n",
- "Collecting aiofiles<24.0,>=22.0 (from gradio)\n",
- " Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)\n",
- "Requirement already satisfied: altair<6.0,>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.2.2)\n",
- "Collecting fastapi (from gradio)\n",
- " Downloading fastapi-0.109.0-py3-none-any.whl (92 kB)\n",
- "Collecting ffmpy (from gradio)\n",
- " Downloading ffmpy-0.3.1.tar.gz (5.5 kB)\n",
- " Preparing metadata (setup.py) ... done\n",
- "Collecting gradio-client==0.8.0 (from gradio)\n",
- " Downloading gradio_client-0.8.0-py3-none-any.whl (305 kB)\n",
- "Collecting httpx (from gradio)\n",
- " Downloading httpx-0.26.0-py3-none-any.whl (75 kB)\n",
- "Requirement already satisfied: huggingface-hub>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.20.2)\n",
- "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.1.1)\n",
- "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.2)\n",
- "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.3)\n",
- "Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n",
- "Requirement already satisfied: numpy~=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.23.5)\n",
- "Collecting orjson~=3.0 (from gradio)\n",
- " Downloading orjson-3.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (138 kB)\n",
- "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from gradio) (23.2)\n",
- "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.5.3)\n",
- "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (9.4.0)\n",
- "Collecting pydantic>=2.0 (from gradio)\n",
- " Downloading pydantic-2.5.3-py3-none-any.whl (381 kB)\n",
- "Collecting pydub (from gradio)\n",
- " Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n",
- "Collecting python-multipart (from gradio)\n",
- " Downloading python_multipart-0.0.6-py3-none-any.whl (45 kB)\n",
- "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n",
- "Collecting semantic-version~=2.0 (from gradio)\n",
- " Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n",
- "Collecting tomlkit==0.12.0 (from gradio)\n",
- " Downloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\n",
- "Requirement already satisfied: typer[all]<1.0,>=0.9 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.9.0)\n",
- "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.5.0)\n",
- "Collecting uvicorn>=0.14.0 (from gradio)\n",
- " Downloading uvicorn-0.25.0-py3-none-any.whl (60 kB)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client==0.8.0->gradio) (2023.6.0)\n",
- "Collecting websockets<12.0,>=10.0 (from gradio-client==0.8.0->gradio)\n",
- " Downloading websockets-11.0.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (129 kB)\n",
- "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.4)\n",
- "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (4.19.2)\n",
- "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.3->gradio) (3.13.1)\n",
- "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.3->gradio) (2.31.0)\n",
- "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.3->gradio) (4.66.1)\n",
- "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.2.0)\n",
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (0.12.1)\n",
- "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (4.47.0)\n",
- "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.4.5)\n",
- "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (3.1.1)\n",
- "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2023.3.post1)\n",
- "Collecting annotated-types>=0.4.0 (from pydantic>=2.0->gradio)\n",
- " Downloading annotated_types-0.6.0-py3-none-any.whl (12 kB)\n",
- "Collecting pydantic-core==2.14.6 (from pydantic>=2.0->gradio)\n",
- " Downloading pydantic_core-2.14.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)\n",
- "Collecting typing-extensions~=4.0 (from gradio)\n",
- " Downloading typing_extensions-4.9.0-py3-none-any.whl (32 kB)\n",
- "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer[all]<1.0,>=0.9->gradio) (8.1.7)\n",
- "Collecting colorama<0.5.0,>=0.4.3 (from typer[all]<1.0,>=0.9->gradio)\n",
- " Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
- "Collecting shellingham<2.0.0,>=1.3.0 (from typer[all]<1.0,>=0.9->gradio)\n",
- " Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)\n",
- "Requirement already satisfied: rich<14.0.0,>=10.11.0 in /usr/local/lib/python3.10/dist-packages (from typer[all]<1.0,>=0.9->gradio) (13.7.0)\n",
- "Collecting h11>=0.8 (from uvicorn>=0.14.0->gradio)\n",
- " Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n",
- "Collecting starlette<0.36.0,>=0.35.0 (from fastapi->gradio)\n",
- " Downloading starlette-0.35.1-py3-none-any.whl (71 kB)\n",
- "Requirement already satisfied: anyio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (3.7.1)\n",
- "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (2023.11.17)\n",
- "Collecting httpcore==1.* (from httpx->gradio)\n",
- " Downloading httpcore-1.0.2-py3-none-any.whl (76 kB)\n",
- "Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (3.6)\n",
- "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (1.3.0)\n",
- "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.2.0)\n",
- "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.12.1)\n",
- "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.32.1)\n",
- "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.16.2)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n",
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (3.0.0)\n",
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (2.16.1)\n",
- "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio->httpx->gradio) (1.2.0)\n",
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.3.2)\n",
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.19.3->gradio) (2.0.7)\n",
- "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich<14.0.0,>=10.11.0->typer[all]<1.0,>=0.9->gradio) (0.1.2)\n",
- "Building wheels for collected packages: ffmpy\n",
- " Building wheel for ffmpy (setup.py) ... done\n",
- " Created wheel for ffmpy: filename=ffmpy-0.3.1-py3-none-any.whl size=5579 sha256=6bf054888300bb0e901364bd78e5f6024f5fafc6636ec3a1c973b9b522ecd753\n",
- " Stored in directory: /root/.cache/pip/wheels/01/a6/d1/1c0828c304a4283b2c1639a09ad86f83d7c487ef34c6b4a1bf\n",
- "Successfully built ffmpy\n",
- "Installing collected packages: pydub, ffmpy, websockets, typing-extensions, tomlkit, shellingham, semantic-version, python-multipart, orjson, h11, colorama, annotated-types, aiofiles, uvicorn, starlette, pydantic-core, httpcore, pydantic, httpx, gradio-client, fastapi, gradio\n",
- " Attempting uninstall: typing-extensions\n",
- " Found existing installation: typing_extensions 4.5.0\n",
- " Uninstalling typing_extensions-4.5.0:\n",
- " Successfully uninstalled typing_extensions-4.5.0\n",
- " Attempting uninstall: pydantic\n",
- " Found existing installation: pydantic 1.10.13\n",
- " Uninstalling pydantic-1.10.13:\n",
- " Successfully uninstalled pydantic-1.10.13\n",
- "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
- "lida 0.0.10 requires kaleido, which is not installed.\n",
- "llmx 0.0.15a0 requires cohere, which is not installed.\n",
- "llmx 0.0.15a0 requires openai, which is not installed.\n",
- "llmx 0.0.15a0 requires tiktoken, which is not installed.\n",
- "tensorflow-probability 0.22.0 requires typing-extensions<4.6.0, but you have typing-extensions 4.9.0 which is incompatible.\n",
- "Successfully installed aiofiles-23.2.1 annotated-types-0.6.0 colorama-0.4.6 fastapi-0.109.0 ffmpy-0.3.1 gradio-4.14.0 gradio-client-0.8.0 h11-0.14.0 httpcore-1.0.2 httpx-0.26.0 orjson-3.9.10 pydantic-2.5.3 pydantic-core-2.14.6 pydub-0.25.1 python-multipart-0.0.6 semantic-version-2.10.0 shellingham-1.5.4 starlette-0.35.1 tomlkit-0.12.0 typing-extensions-4.9.0 uvicorn-0.25.0 websockets-11.0.3\n"
- ]
- }
- ],
- "source": [
- "!pip install gradio"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "_4U_dUaQs2i-"
- },
- "outputs": [],
- "source": [
- "# For downloading model weights\n",
- "import os\n",
- "\n",
- "# For image preprocessing\n",
- "import cv2\n",
- "import numpy as np\n",
- "from sklearn.cluster import KMeans\n",
- "\n",
- "# For the image classifier\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "\n",
- "# For the interface\n",
- "import gradio as gr\n",
- "\n",
- "# For formatting data\n",
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "RrOzL7g0zab_"
- },
- "outputs": [],
- "source": [
- "# Global constants:\n",
- "\n",
- "# For preprocessing\n",
- "IMAGE_WIDTH = 1125\n",
- "IMAGE_HEIGHT = 1500\n",
- "THRESH_BLOCK_SIZE = 11\n",
- "THRESH_CONSTANT = 5\n",
- "LINE_LENGTH = 5000\n",
- "\n",
- "# For ROI extraction\n",
- "MAX_COMPONENT_MERGE_DISTANCE = 30\n",
- "MIN_COMPONENT_SIDE = 15\n",
- "Y_EPSILON = 25\n",
- "NUMBER_OF_COLUMNS = 2\n",
- "\n",
- "# For image classification\n",
- "PREDICTION_TO_STRING = [\"A\", \"B\", \"C\", \"D\"]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TFv3JBjAOy08"
- },
- "source": [
- "## THE CLASSIFIER\n",
- "Prepare the Classifier Model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "id": "yK9WDGQQoKOA"
- },
- "outputs": [],
- "source": [
- "# Download the model's state dictionary from the repository\n",
- "GITHUB_URL = \"https://raw.githubusercontent.com/GabrielEdradan/HMC-Grad/main/image_classifier.pth\"\n",
- "MODEL_PATH = \"/content/model.pth\"\n",
- "os.system(f\"wget {GITHUB_URL} -O {MODEL_PATH}\")\n",
- "model_state_dict = torch.load(MODEL_PATH)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "id": "avmc5ro9s1s1"
- },
- "outputs": [],
- "source": [
- "# Define the model class\n",
- "class ABCDClassifier(nn.Module):\n",
- "\n",
- "    def __init__(self):\n",
- "        super(ABCDClassifier, self).__init__()\n",
- "\n",
- "        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)\n",
- "        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)\n",
- "        self.conv2_drop = nn.Dropout2d()\n",
- "        self.fc1 = nn.Linear(32 * 4 * 4, 128)\n",
- "        self.fc2 = nn.Linear(128, 4)\n",
- "\n",
- "    def forward(self, x):\n",
- "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
- "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
- "        x = x.view(-1, 32 * 4 * 4)\n",
- "        x = F.relu(self.fc1(x))\n",
- "        x = F.dropout(x, training=self.training)\n",
- "        x = self.fc2(x)\n",
- "\n",
- "        return F.softmax(x, dim=1)"
- ]
- },
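A quick sanity check of the layer arithmetic above (an illustrative sketch, not part of the original notebook): a 28×28 input becomes 24×24 after the 5×5 conv1, 12×12 after pooling, 8×8 after conv2, and 4×4 after the second pool, which is where fc1's input size of 32 * 4 * 4 comes from.

```python
import torch

# One dummy grayscale 28x28 image, shaped the way classify_roi shapes its input.
dummy = torch.zeros(1, 1, 28, 28)
clf = ABCDClassifier()
clf.eval()
with torch.no_grad():
    probs = clf(dummy)
print(probs.shape)         # torch.Size([1, 4]) -- one probability per class A-D
print(float(probs.sum()))  # ~1.0, since forward() ends with a softmax
```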
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "o3XUKncaOrDJ",
- "outputId": "5b7499b4-fc73-4cf5-f3eb-8b955cf51a96"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "<All keys matched successfully>"
- ]
- },
- "metadata": {},
- "execution_count": 23
- }
- ],
- "source": [
- "# Load the model\n",
- "model = ABCDClassifier()\n",
- "model.load_state_dict(model_state_dict)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "qsWpajjjOtmT"
- },
- "source": [
- "## THE IMAGE PROCESSOR"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7Q5HWgmSQx4t"
- },
- "source": [
- "### The Image Processor Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "UaVWpmkRtXFN"
- },
- "outputs": [],
- "source": [
- "# The main function used for the interface.\n",
- "# Takes in an array of strings (image paths) and an array of chars (correction key)\n",
- "# Returns five file paths: four CSVs (correctness, scores, item analysis, and score analysis) and a merged XLSX\n",
- "def process_image_set(image_paths_array, correction_key):\n",
- "    correctness = []  # an array of binary int arrays, of length len(image_paths_array)\n",
- "    scores = []  # an array of ints, of length len(image_paths_array)\n",
- "    item_analysis = [0] * len(correction_key)  # an array of ints\n",
- "    score_analysis = [0] * (len(correction_key) + 1)  # an array of ints\n",
- "\n",
- "    for i in range(len(image_paths_array)):\n",
- "        process_image(image_paths_array[i], correctness, scores, item_analysis, score_analysis, correction_key)\n",
- "\n",
- "    # Formatting data\n",
- "    # Define csv file paths\n",
- "    student_correctness_csv_path = \"student_correctness.csv\"\n",
- "    student_scores_csv_path = \"student_scores.csv\"\n",
- "    item_analysis_csv_path = \"item_analysis.csv\"\n",
- "    score_analysis_csv_path = \"score_analysis.csv\"\n",
- "    merged_xlsx_path = \"merged_data.xlsx\"\n",
- "\n",
- "    # For correctness\n",
- "    transposed_data = list(map(list, zip(*correctness)))  # Transpose the data to have students as columns\n",
- "    columns = [f\"Student {i+1}\" for i in range(len(correctness))]  # Define the columns\n",
- "    correctness_df = pd.DataFrame(transposed_data, columns=columns)  # Create the DataFrame\n",
- "    correctness_df.insert(0, \"Item Number\", range(1, len(correctness[0]) + 1))  # Add the item number column\n",
- "    correctness_df.to_csv(student_correctness_csv_path, index=False)  # Save\n",
- "\n",
- "    # For student scores\n",
- "    columns = [\"Score\"]  # Define the columns\n",
- "    scores_df = pd.DataFrame(scores, columns=columns)  # Create the DataFrame\n",
- "    scores_df.insert(0, \"Student Number\", range(1, len(scores) + 1))  # Add the student number column\n",
- "    scores_df.to_csv(student_scores_csv_path, index=False)  # Save\n",
- "\n",
- "    # For item analysis\n",
- "    columns = [\"Number of Correct Answers\"]  # Define the columns\n",
- "    item_analysis_df = pd.DataFrame(item_analysis, columns=columns)  # Create the DataFrame\n",
- "    item_analysis_df.insert(0, \"Item Number\", range(1, len(item_analysis) + 1))  # Add the item number column\n",
- "    item_analysis_df.to_csv(item_analysis_csv_path, index=False)  # Save\n",
- "\n",
- "    # For score analysis\n",
- "    columns = [\"Number of Students\"]  # Define the columns\n",
- "    score_analysis_df = pd.DataFrame(score_analysis, columns=columns)  # Create the DataFrame\n",
- "    score_analysis_df.insert(0, \"Score\", range(0, len(score_analysis)))  # Add the score column\n",
- "    score_analysis_df.to_csv(score_analysis_csv_path, index=False)  # Save\n",
- "\n",
- "    # For merging the CSV data into a single XLSX\n",
- "    # Create a writer to save multiple dataframes to a single XLSX file\n",
- "    with pd.ExcelWriter(merged_xlsx_path) as writer:\n",
- "        # Write each dataframe to a different sheet\n",
- "        correctness_df.to_excel(writer, sheet_name=\"Correctness\", index=False)\n",
- "        scores_df.to_excel(writer, sheet_name=\"Scores\", index=False)\n",
- "        item_analysis_df.to_excel(writer, sheet_name=\"Item Analysis\", index=False)\n",
- "        score_analysis_df.to_excel(writer, sheet_name=\"Score Analysis\", index=False)\n",
- "\n",
- "    return student_correctness_csv_path, student_scores_csv_path, item_analysis_csv_path, score_analysis_csv_path, merged_xlsx_path\n",
- "\n",
- "\n",
- "# A helper function for readability\n",
- "# Takes in an image path, the four arrays to be modified, and the correction key (an array of chars)\n",
- "# Void return value; modifies the four arrays directly\n",
- "def process_image(img_path, img_cor_arr, img_scr_arr, itm_ana_arr, scr_ana_arr, correction_key):\n",
- "    base_image = cv2.imread(img_path)\n",
- "\n",
- "    # Preprocessing\n",
- "    segmentation_image, ocr_image, processing_error = preprocess(base_image)\n",
- "    if processing_error:  # Check for exception\n",
- "        invalid_num_of_items(img_scr_arr, \"PROCESSING ERROR\")\n",
- "        return\n",
- "\n",
- "    # ROI Extraction\n",
- "    rois, extraction_error = extract_rois(segmentation_image, ocr_image, len(correction_key))\n",
- "    if extraction_error:  # Check for exception\n",
- "        invalid_num_of_items(img_scr_arr, \"EXTRACTION ERROR\")\n",
- "        return\n",
- "\n",
- "    # Classification\n",
- "    item_answers = classify_rois(rois)\n",
- "\n",
- "    if len(item_answers) != len(correction_key):  # Check for exception (extra layer of safety)\n",
- "        invalid_num_of_items(img_scr_arr, \"NUM OF ITEM ANSWERS != CORRECTION KEY\")\n",
- "        return\n",
- "\n",
- "    # Grading and Analysis\n",
- "    grade_and_analyze(item_answers, img_cor_arr, img_scr_arr, itm_ana_arr, scr_ana_arr, correction_key)\n",
- "\n",
- "\n",
- "# Function for error logic: sets the score to -1 (a score of -1 marks an invalid image)\n",
- "# Takes in the array to be modified (score array) and an error string\n",
- "# Void return value; modifies the array directly\n",
- "def invalid_num_of_items(score_array, error_string):\n",
- "    score_array.append(-1)\n",
- "    print(f\"ERROR: {error_string}\")"
- ]
- },
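A minimal usage sketch of calling the pipeline directly (not in the original notebook), mirroring the sample data the demo cell downloads later: the function always writes the four CSVs plus the merged workbook and returns their paths.

```python
# Mirrors the notebook's bundled examples: two sample sheets and a 50-item key
# (the "X" in position 5 is a bonus item that accepts any answer).
paths = ["/content/sample_1.jpg", "/content/sample_2.jpg"]
key = list("ABCDXABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDA")

for out_path in process_image_set(paths, key):
    print(out_path)  # student_correctness.csv, student_scores.csv, ...
```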
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NwhaA5O4QsDv"
- },
- "source": [
- "### Image Preprocessing Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "id": "JzWGfdPG58Cb"
- },
- "outputs": [],
- "source": [
- "# Takes in an image (the base image)\n",
- "# Returns an image for segmentation, an image for OCR, and a boolean for error handling\n",
- "def preprocess(input_image):\n",
- "    # ------------------------------------ FOR OCR IMAGE ------------------------------------ #\n",
- "    resized_image = cv2.resize(input_image, (IMAGE_WIDTH, IMAGE_HEIGHT))  # MOVED FROM INTERFACE TO HERE\n",
- "    gray = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)\n",
- "    threshed_gray = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, THRESH_BLOCK_SIZE, THRESH_CONSTANT)\n",
- "    opened = cv2.bitwise_not(cv2.morphologyEx(cv2.bitwise_not(threshed_gray), cv2.MORPH_OPEN, np.ones((2, 2), np.uint8), iterations=1))\n",
- "    ocr_image = remove_lines(opened)\n",
- "\n",
- "    # ----------------------------------- FOR HEADER MASK ----------------------------------- #\n",
- "    blur = cv2.GaussianBlur(gray, (9, 9), 0)\n",
- "    threshed_blur = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, THRESH_BLOCK_SIZE, THRESH_CONSTANT)\n",
- "    removed_lines = remove_lines(threshed_blur, 6)\n",
- "\n",
- "    # Determine divider y positions\n",
- "    intensely_dilated = cv2.dilate(cv2.bitwise_not(removed_lines), cv2.getStructuringElement(cv2.MORPH_RECT, (2000, 40)), iterations=1)\n",
- "    black_ys = np.where(np.any(intensely_dilated == 0, axis=1))[0].tolist()\n",
- "    to_be_removed = []\n",
- "    for i in range(1, len(black_ys)):\n",
- "        if black_ys[i] - black_ys[i - 1] == 1:\n",
- "            to_be_removed.append(i)\n",
- "    stripe_positions = [black_ys[idx] for idx in range(len(black_ys)) if idx not in to_be_removed]\n",
- "\n",
- "    # Determine which stripe is the header border\n",
- "    num_of_stripe_positions = len(stripe_positions)\n",
- "    header_border = 0\n",
- "    match num_of_stripe_positions:\n",
- "        case 1:\n",
- "            header_border = stripe_positions[0]\n",
- "        case 2:\n",
- "            header_border = stripe_positions[0] if stripe_positions[1] > 500 else stripe_positions[1]\n",
- "        case _:\n",
- "            found = False\n",
- "            for stripe_position in stripe_positions:\n",
- "                if stripe_position > 100 and stripe_position < 500:\n",
- "                    header_border = stripe_position\n",
- "                    found = True\n",
- "                    continue\n",
- "            if not found:  # If there is no stripe, consider the image invalid (subject to change)\n",
- "                return [0], [0], True\n",
- "\n",
- "    # Create a mask based on header border\n",
- "    mask = np.ones(intensely_dilated.shape, dtype=np.uint8) * 255\n",
- "    mask[header_border:, :] = 0\n",
- "\n",
- "    # ---------------------------------- FOR COMPONENT FINDING IMAGE ---------------------------------- #\n",
- "    masked_removed_lines = cv2.bitwise_or(removed_lines, mask)\n",
- "    segmentation_image = cv2.dilate(cv2.bitwise_not(masked_removed_lines), cv2.getStructuringElement(cv2.MORPH_RECT, (10, 6)), iterations=1)\n",
- "\n",
- "    return segmentation_image, ocr_image, False"
- ]
- },
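The adaptive threshold used twice above binarizes each pixel against the mean of its THRESH_BLOCK_SIZE × THRESH_BLOCK_SIZE neighborhood minus THRESH_CONSTANT, which is what keeps handwriting legible under uneven lighting. A small synthetic illustration (assumes the globals cell has run; not from the original notebook):

```python
import cv2
import numpy as np

# A horizontal brightness gradient with one dark mark on the dim side and one
# mid-gray mark on the bright side; a single global threshold would lose one of them.
page = np.tile(np.linspace(50, 200, 200).astype(np.uint8), (50, 1))
cv2.circle(page, (30, 25), 4, 0, -1)    # "ink" on the dim side
cv2.circle(page, (170, 25), 4, 90, -1)  # lighter "ink" on the bright side
binary = cv2.adaptiveThreshold(page, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                               cv2.THRESH_BINARY, THRESH_BLOCK_SIZE, THRESH_CONSTANT)
print(binary[25, 30], binary[25, 170])  # 0 0 -- both marks survive binarization
```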
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "id": "oxn1iQUX7be6"
- },
- "outputs": [],
- "source": [
- "# A helper function for the preprocess function\n",
- "# Takes in an image and, optionally, a line thickness (an int that determines how aggressively the notepad lines are erased)\n",
- "# Returns a modified image in which the notepad lines have been erased where possible\n",
- "def remove_lines(input_image, thickness=2):\n",
- "    edges = cv2.Canny(input_image, 200, 200)\n",
- "    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)\n",
- "\n",
- "    # Check for when no lines are found (cv2.HoughLines returns None in that case)\n",
- "    if lines is None:\n",
- "        return input_image\n",
- "\n",
- "    # Create a mask based on the lines found by HoughLines\n",
- "    line_mask = np.zeros_like(input_image)\n",
- "    for line in lines:\n",
- "        rho, theta = line[0]\n",
- "        if theta < 1.0 or theta > 2.0:\n",
- "            continue\n",
- "        a = np.cos(theta)\n",
- "        b = np.sin(theta)\n",
- "        x0 = a * rho\n",
- "        y0 = b * rho\n",
- "        x1 = int(x0 + LINE_LENGTH * (-b))\n",
- "        y1 = int(y0 + LINE_LENGTH * (a))\n",
- "        x2 = int(x0 - LINE_LENGTH * (-b))\n",
- "        y2 = int(y0 - LINE_LENGTH * (a))\n",
- "        cv2.line(line_mask, (x1, y1), (x2, y2), (255, 255, 255), 2)\n",
- "\n",
- "    # Dilate the lines vertically based on the thickness parameter\n",
- "    dilation_kernel = np.ones((thickness, 1), np.uint8)\n",
- "    line_mask = cv2.dilate(line_mask, dilation_kernel, iterations=1)\n",
- "\n",
- "    # White out the masked lines in the base image (bitwise OR) and apply MORPH_OPEN to denoise\n",
- "    sub_result = cv2.bitwise_or(input_image, line_mask)\n",
- "    final_result = cv2.bitwise_not(cv2.morphologyEx(cv2.bitwise_not(sub_result), cv2.MORPH_OPEN, np.ones((2, 2), np.uint8), iterations=1))\n",
- "    return final_result"
- ]
- },
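For readers unfamiliar with the (rho, theta) parameterization that HoughLines returns: rho is the perpendicular distance from the origin to the line and theta the angle of that perpendicular, so the filter above (1.0 to 2.0 radians, roughly 57° to 115°) keeps only lines whose normal is near-vertical, i.e. the near-horizontal notepad rules. A worked example with illustrative values:

```python
import numpy as np

rho, theta = 400.0, np.pi / 2        # a horizontal line at y = 400 (theta ~ 1.57 passes the filter)
a, b = np.cos(theta), np.sin(theta)  # a ~ 0.0, b = 1.0
x0, y0 = a * rho, b * rho            # foot of the perpendicular: (0, 400)
p1 = (int(x0 + 5000 * -b), int(y0 + 5000 * a))
p2 = (int(x0 - 5000 * -b), int(y0 - 5000 * a))
print(p1, p2)  # (-5000, 400) (5000, 400) -- far beyond any image edge, as intended
```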
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "61Yx4uiBQ4K0"
- },
- "source": [
- "### Region of Interest Extraction Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "id": "mgIH1Imp1EQ4"
- },
- "outputs": [],
- "source": [
- "# Takes in an image for segmentation, an image for OCR, and the target number of ROIs\n",
- "# Returns an array of ROIs (each of shape [y, x]) and a boolean for error handling\n",
- "def extract_rois(segmentation_image, ocr_image, num_of_items):\n",
- "    # Get the components in the input image\n",
- "    _, _, stats, centroids = cv2.connectedComponentsWithStats(segmentation_image, connectivity=4)\n",
- "\n",
- "    # Define all bounds (stats excluding area)\n",
- "    all_bounds = [stat[:4].tolist() for stat in stats]\n",
- "\n",
- "    # Remove the background from bounds and centroids\n",
- "    all_bounds.pop(0)\n",
- "    centroids = centroids.tolist()\n",
- "    centroids.pop(0)\n",
- "\n",
- "    # Find components that are close to each other and merge them\n",
- "    nearby_component_pairs = find_nearby_pairs(centroids, MAX_COMPONENT_MERGE_DISTANCE)\n",
- "    mergeable_components = find_mergeable_components(nearby_component_pairs)\n",
- "    merged_bounds = merge_groups(mergeable_components, all_bounds)\n",
- "\n",
- "    # Make an array of bounds, excluding bounds that were used in merging and bounds that are too small\n",
- "    component_bounds = [bound[:4] for index, bound in enumerate(all_bounds) if not any(index in group for group in mergeable_components) and (bound[2] > MIN_COMPONENT_SIDE and bound[3] > MIN_COMPONENT_SIDE)]\n",
- "\n",
- "    # Add the merged bounds\n",
- "    component_bounds.extend(merged_bounds)\n",
- "\n",
- "    # Sort components into two columns\n",
- "    component_bounds = sort_into_columns(component_bounds)\n",
- "\n",
- "    # At this point, each \"row\" of a column typically contains more than one component at the same y (within Y_EPSILON)\n",
- "    # Remove all components except the rightmost one in each row\n",
- "    component_bounds = filter_non_letters(component_bounds)\n",
- "\n",
- "    # Convert bounds to ROIs\n",
- "    rois = []\n",
- "    for bound in component_bounds:\n",
- "        x, y, w, h = bound\n",
- "        roi_img = ocr_image[y:y+h, x:x+w]\n",
- "        rois.append(roi_img)\n",
- "\n",
- "    # Handle exception: if the number of ROIs found differs from the target, consider the image invalid\n",
- "    if len(rois) != num_of_items:\n",
- "        return [0], True\n",
- "\n",
- "    return rois, False"
- ]
- },
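A minimal illustration (not from the original notebook) of the stats layout the function relies on: cv2.connectedComponentsWithStats returns one row per component with columns [x, y, width, height, area], and row 0 is the background, which is why the code pops the first bound and centroid.

```python
import cv2
import numpy as np

img = np.zeros((10, 10), dtype=np.uint8)
img[2:5, 2:5] = 255  # a 3x3 blob
img[7:9, 6:9] = 255  # a 2x3 blob

n, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity=4)
print(n)         # 3: the background plus two blobs
print(stats[1])  # [2 2 3 3 9] -> x, y, w, h, area of the first blob
```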
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "id": "1TrffyMX1XYB"
- },
- "outputs": [],
- "source": [
- "# A helper function for the ROI extraction function\n",
- "# Takes in an array of centroids (component midpoints) and the max merge distance\n",
- "# Returns an array of tuples representing nearby pairs of components\n",
- "def find_nearby_pairs(centroids_array, max_merge_distance):\n",
- "    nearby_pairs = []\n",
- "    num_components = len(centroids_array)\n",
- "    for i in range(num_components - 1):\n",
- "        for j in range(i + 1, num_components):\n",
- "            distance = np.linalg.norm(np.array(centroids_array[i]) - np.array(centroids_array[j]))\n",
- "            if distance <= max_merge_distance:\n",
- "                nearby_pairs.append((i, j))\n",
- "    return nearby_pairs"
- ]
- },
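A quick check of the pairing logic with illustrative centroids: at the 30-pixel merge distance set in the globals, the first two components below pair up (distance ≈ 22.4) while the third stays separate.

```python
centroids = [[100.0, 100.0], [120.0, 110.0], [400.0, 100.0]]
print(find_nearby_pairs(centroids, MAX_COMPONENT_MERGE_DISTANCE))  # [(0, 1)]
```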
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {
- "id": "Y2Ludp0s1f0A"
- },
- "outputs": [],
- "source": [
- "# A helper function for the ROI extraction function\n",
- "# Takes in an array of nearby pairs\n",
- "# Returns an array of sets, each representing two or more components that are close to each other\n",
- "def find_mergeable_components(nearby_pairs):\n",
- "    groups = []\n",
- "    for pair in nearby_pairs:\n",
- "        group_found = False\n",
- "        for group in groups:\n",
- "            if any(component in group for component in pair):  # Checks, for each component of the pair, whether it is already in the group\n",
- "                group.update(pair)\n",
- "                group_found = True\n",
- "                break\n",
- "        if not group_found:\n",
- "            groups.append(set(pair))\n",
- "    return groups"
- ]
- },
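One subtlety worth noting: this is a single-pass grouping rather than a full union-find, so a pair that bridges two already-formed groups extends only the first matching group instead of merging the two. For chains discovered in index order, as find_nearby_pairs produces them, it behaves as expected:

```python
# (0, 1) seeds a group, (1, 2) extends it, (3, 4) starts a new one.
print(find_mergeable_components([(0, 1), (1, 2), (3, 4)]))  # [{0, 1, 2}, {3, 4}]
```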
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "id": "iUKF7Ns81hPd"
- },
- "outputs": [],
- "source": [
- "# A helper function for the ROI extraction function\n",
- "# Takes in an array of mergeable groups and all the component bounds\n",
- "# Returns an array of the bounds that result from merging each mergeable group\n",
- "def merge_groups(groups, bounds):\n",
- "    merged_bounds = []\n",
- "    for group in groups:\n",
- "        min_x = 5000\n",
- "        min_y = 5000\n",
- "        max_x = 0\n",
- "        max_y = 0\n",
- "        for component in group:\n",
- "            x, y, w, h = bounds[component]\n",
- "            min_x = min(min_x, x)\n",
- "            min_y = min(min_y, y)\n",
- "            max_x = max(max_x, x+w)\n",
- "            max_y = max(max_y, y+h)\n",
- "        merged_bounds.append([min_x, min_y, max_x - min_x, max_y - min_y])\n",
- "    return merged_bounds"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "id": "92pqbVg_1iYp"
- },
- "outputs": [],
- "source": [
- "# A helper function for the ROI extraction function\n",
- "# Takes in an array of component bounds\n",
- "# Returns the component bounds sorted into k columns\n",
- "def sort_into_columns(component_bounds):\n",
- "    # Determine the x coordinates of the bounds\n",
- "    x_coordinates = [bound[0] for bound in component_bounds]\n",
- "    x_coordinates = np.array(x_coordinates).reshape(-1, 1)\n",
- "\n",
- "    # Set the number of clusters (k) to NUMBER_OF_COLUMNS, as the number of columns is predefined\n",
- "    optimal_k = NUMBER_OF_COLUMNS\n",
- "\n",
- "    # Perform K-means clustering with k = NUMBER_OF_COLUMNS\n",
- "    kmeans = KMeans(n_clusters=optimal_k, init=\"k-means++\", max_iter=300, n_init=10, random_state=0)\n",
- "    kmeans.fit(x_coordinates)\n",
- "\n",
- "    # Group the components based on the cluster assignments\n",
- "    grouped_components = [[] for _ in range(optimal_k)]  # An array of two arrays\n",
- "    for i, label in enumerate(kmeans.labels_):  # kmeans.labels_ is an array of ints giving each component's cluster label\n",
- "        grouped_components[label].append(component_bounds[i])\n",
- "\n",
- "    # Sort into a single list: columns left to right, components top to bottom within each column\n",
- "    sorted_components = []\n",
- "    grouped_components = sorted(grouped_components, key=lambda group: group[0][0])\n",
- "    for group in grouped_components:\n",
- "        sorted_group = sorted(group, key=lambda component: component[1])\n",
- "        sorted_components.extend(sorted_group)\n",
- "\n",
- "    return sorted_components"
- ]
- },
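Because NUMBER_OF_COLUMNS is fixed at 2, this is effectively one-dimensional k-means on the components' left-edge x coordinates. A minimal sketch with illustrative coordinates (not from the original notebook):

```python
import numpy as np
from sklearn.cluster import KMeans

xs = np.array([100, 110, 105, 600, 590, 610]).reshape(-1, 1)
kmeans = KMeans(n_clusters=2, init="k-means++", n_init=10, random_state=0).fit(xs)
print(kmeans.labels_)  # e.g. [1 1 1 0 0 0] -- left column vs. right column
```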
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {
- "id": "hMA60xDa1jWt"
- },
- "outputs": [],
- "source": [
- "# A helper function for the ROI extraction function\n",
- "# Takes in an array of component bounds\n",
- "# Returns an array of component bounds that excludes components whose position marks them as non-letters\n",
- "def filter_non_letters(component_bounds):\n",
- "    # Define dictionaries mapping components to their origin-x and centroid-y\n",
- "    comp_x_dict = {}\n",
- "    for index, bound in enumerate(component_bounds):\n",
- "        comp_x_dict[index] = bound[0]\n",
- "\n",
- "    comp_cent_y_dict = {}\n",
- "    for index, bound in enumerate(component_bounds):\n",
- "        comp_cent_y_dict[index] = bound[1] + (bound[3] / 2)\n",
- "\n",
- "\n",
- "    # Function that compares two keys and decides if the current key should be removed\n",
- "    # Takes in the current key, the key to check against, a dictionary of component origin-x, and a dictionary of component centroid-y\n",
- "    # Modifies the centroid-y dictionary directly; returns True if the current key has been removed\n",
- "    def check_key_for_removal(curr_key, key_to_check, x_dict, y_dict):\n",
- "        if key_to_check not in y_dict or curr_key not in y_dict:\n",
- "            return False\n",
- "        if abs(y_dict[key_to_check] - y_dict[curr_key]) > Y_EPSILON:\n",
- "            return False\n",
- "\n",
- "        curr_key_x = x_dict[curr_key]\n",
- "        key_to_check_x = x_dict[key_to_check]\n",
- "        if key_to_check_x > curr_key_x:\n",
- "            y_dict.pop(curr_key)\n",
- "            return True\n",
- "        return False\n",
- "\n",
- "    # Based on the components' centroid-y, determine which components are in the same \"row\", i.e. components whose centroid-y are within Y_EPSILON\n",
- "    # If two components have significantly different centroid-y, ignore them; otherwise, check whether the current one is to the left or to the right\n",
- "    # If the current component is to the left of some other component in the same \"row\", remove it from comp_cent_y_dict\n",
- "    dup_y_dict = comp_cent_y_dict.copy()\n",
- "    for key in dup_y_dict:\n",
- "        prev_prev_key = key - 2\n",
- "        prev_key = key - 1\n",
- "        next_key = key + 1\n",
- "        if key not in comp_cent_y_dict:\n",
- "            continue\n",
- "\n",
- "        done = check_key_for_removal(key, prev_prev_key, comp_x_dict, comp_cent_y_dict)\n",
- "        if done: continue  # If the curr_key has been removed, go to the next key\n",
- "\n",
- "        done = check_key_for_removal(key, prev_key, comp_x_dict, comp_cent_y_dict)\n",
- "        if done: continue\n",
- "\n",
- "        done = check_key_for_removal(key, next_key, comp_x_dict, comp_cent_y_dict)\n",
- "\n",
- "    # Create an array of only those component bounds that remain in comp_cent_y_dict\n",
- "    filtered_component_bounds = [bound for index, bound in enumerate(component_bounds) if index in comp_cent_y_dict]\n",
- "\n",
- "    return filtered_component_bounds"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JeW50SezRA9U"
- },
- "source": [
- "### Image Classification Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "id": "_S0V_bbVkknf"
- },
- "outputs": [],
- "source": [
- "# Takes in an array of ROIs (numpy arrays)\n",
- "# Returns an array of item answers (letter strings)\n",
- "def classify_rois(rois):\n",
- "    item_answers = []\n",
- "    for roi in rois:\n",
- "        item_answers.append(classify_roi(roi))\n",
- "    return item_answers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {
- "id": "LWsrH1MTRBnX"
- },
- "outputs": [],
- "source": [
- "# Takes in an ROI expressed as a numpy array of shape [y, x]\n",
- "# Returns the classifier's classification of the ROI\n",
- "def classify_roi(roi):\n",
- "    # Preprocess ROI\n",
- "    new_array = np.full((28, 28), 255, dtype=np.uint8)\n",
- "    small_roi = cv2.resize(roi, (26, 26))\n",
- "    new_array[1:27, 1:27] = small_roi\n",
- "\n",
- "    roi = new_array\n",
- "    roi = cv2.bitwise_not(roi)  # Invert to fit model requirement\n",
- "    roi = roi / 255.0  # Normalize\n",
- "\n",
- "    roi = torch.from_numpy(roi)  # Convert to tensor\n",
- "    roi = roi.view(1, 1, 28, 28)  # Reshape to fit model requirement\n",
- "    roi = roi.to(torch.float32)  # Change data type to fit model requirement\n",
- "\n",
- "    # Classify the ROI\n",
- "    model.eval()\n",
- "    output = model(roi)  # Output is a tensor with 4 floats representing class probabilities\n",
- "    prediction = output.argmax(dim=1, keepdim=True).item()  # Returns a number from 0 to 3, representing the most probable class\n",
- "    predicted_letter = PREDICTION_TO_STRING[prediction]  # Convert the number to the corresponding letter\n",
- "\n",
- "    return predicted_letter"
- ]
- },
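A small inference detail (a suggested refinement, not in the original cell): model.eval() disables the dropout layers, but gradients are still tracked by default, so wrapping the forward pass in torch.no_grad() avoids building the autograd graph. Mirroring the last lines of the cell above:

```python
model.eval()
with torch.no_grad():  # skip autograd bookkeeping during inference
    output = model(roi)
prediction = output.argmax(dim=1).item()
```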
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3CTZ7GAUikON"
- },
- "source": [
- "### Grading and Analysis Function"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {
- "id": "RxkhdVWriaz9"
- },
- "outputs": [],
- "source": [
- "# Takes in the item answers, the four accumulator arrays, and the correction key\n",
- "# Has no return value; instead, updates the arrays in place\n",
- "def grade_and_analyze(item_answers, img_cor_arr, img_scr_arr, itm_ana_arr, scr_ana_arr, correction_key):\n",
- "    correction_array = []\n",
- "    score = 0\n",
- "\n",
- "    for i in range(len(correction_key)):\n",
- "        if item_answers[i] == correction_key[i] or correction_key[i] == \"X\":\n",
- "            correction_array.append(1)  # Update correction array\n",
- "            score += 1  # Update score\n",
- "            itm_ana_arr[i] += 1  # Update item analysis\n",
- "        else:\n",
- "            correction_array.append(0)  # Update correction array\n",
- "\n",
- "    # Bring changes to the input arrays (except item analysis, which was updated during the loop)\n",
- "    img_cor_arr.append(correction_array)\n",
- "    img_scr_arr.append(score)\n",
- "    scr_ana_arr[score] += 1  # Update score analysis"
- ]
- },
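A worked example of the grading pass on a hypothetical three-item test: with key ["A", "B", "X"] and answers ["A", "C", "D"], items 1 and 3 count as correct (the "X" accepts any answer), giving a correctness row of [1, 0, 1] and a score of 2.

```python
correctness, scores = [], []
item_analysis = [0, 0, 0]      # one slot per item
score_analysis = [0, 0, 0, 0]  # one slot per possible score, 0 through 3

grade_and_analyze(["A", "C", "D"], correctness, scores,
                  item_analysis, score_analysis, ["A", "B", "X"])
print(correctness)     # [[1, 0, 1]]
print(scores)          # [2]
print(item_analysis)   # [1, 0, 1]
print(score_analysis)  # [0, 0, 1, 0]
```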
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "aolbZpscRhXM"
- },
- "source": [
- "## THE INTERFACE\n",
- "Use Gradio to create a user-friendly interface."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {
- "id": "1_LR6TyoZxVJ"
- },
- "outputs": [],
- "source": [
- "# Download the sample images for the demo interface\n",
- "SAMPLE_IMAGES_FOLDER_URL = \"https://raw.githubusercontent.com/GabrielEdradan/HMC-Grad/main/sample_images/\"\n",
- "IMAGE_SAVE_PATH = \"/content/\"\n",
- "\n",
- "NUM_OF_SAMPLES = 2\n",
- "for i in range(1, NUM_OF_SAMPLES + 1):\n",
- "    suffix = f\"sample_{i}.jpg\"\n",
- "    os.system(f\"wget {SAMPLE_IMAGES_FOLDER_URL}{suffix} -O {IMAGE_SAVE_PATH}{suffix}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 645
- },
- "id": "AnBH_XUARhyV",
- "outputId": "8983bc54-1953-47a6-950a-0b5a48735d6f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
- "\n",
- "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
- "Running on public URL: https://65d4d822e6a4e8af8d.gradio.live\n",
- "\n",
- "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
- ]
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ],
- "text/html": [
- "<div><iframe src=\"https://65d4d822e6a4e8af8d.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
- ]
- },
- "metadata": {}
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": []
- },
- "metadata": {},
- "execution_count": 37
- }
- ],
- "source": [
- "# Define the interface function\n",
- "# Takes in two kinds of input: the correction key (a string) and the answer sheet images (a list of file paths, received via *images)\n",
- "def interface_function(correction_key_string, *images):\n",
- "\n",
- "    # Set the correction key (case-insensitively; only A, B, C, D, and X are accepted)\n",
- "    correction_key = []\n",
- "    for item in correction_key_string.upper():\n",
- "        if item in [\"A\", \"B\", \"C\", \"D\", \"X\"]:\n",
- "            correction_key.append(item)\n",
- "\n",
- "    # Run the main function\n",
- "    return process_image_set(images[0], correction_key)\n",
- "\n",
- "\n",
- "# Define the Gradio interface\n",
- "\n",
- "# Set the description string\n",
- "desc = '''\n",
- "This is the demo interface for the research project titled \"HMC-Grad: Automating Handwritten Multiple-Choice Test Grading Using Computer Vision and Deep Learning\" by Edradan, G., Serrano, D., and Tunguia, T.\n",
- "\n",
- "Instructions:\n",
- "\n",
- "First Input: Enter the correction key. It must be a continuous string of text, with each letter representing the correct answer for each item in consecutive order. The only letters accepted are A, B, C, and D, but an X can be written to represent an item that would accept any answer (bonus item).\n",
- "\n",
- "Second Input: Upload the images of the papers to be evaluated and analyzed. The order of evaluation is based on the order in which the images are uploaded.\n",
- "\n",
- "For better results, the following are recommended:\n",
- "\n",
- "Regarding the documents to be evaluated:\n",
- " - Have substantial left and right margins\n",
- " - Provide a blank line between the header and the answers\n",
- " - Write the answers in capitals and in two columns\n",
- " - Avoid having the answers overlap with the notepad lines\n",
- " - Have significant space between the numbers and the letters\n",
- " - Write the item numbers smaller than the letters\n",
- "\n",
- "Regarding the photo:\n",
- " - Have an aspect ratio of 3:4\n",
- " - Have a resolution of at least 1125 px by 1500 px\n",
- " - Have adequate lighting; use flash if necessary\n",
- "\n",
- "'''\n",
- "\n",
- "interface = gr.Interface(\n",
- "    fn=interface_function,\n",
- "    title=\"HMC-Grad: Handwritten Multiple-Choice Test Grader\",\n",
- "    description=desc,\n",
- "    allow_flagging=\"never\",\n",
- "    inputs=[gr.Textbox(label=\"Correction Key\", placeholder=\"ABCDABCDABCD\"),\n",
- "            gr.File(file_count=\"multiple\", file_types=[\".jpg\"], label=\"Upload Image(s)\")],\n",
- "    outputs=[gr.File(label=\"Correctness\"),\n",
- "             gr.File(label=\"Scores\"),\n",
- "             gr.File(label=\"Item Analysis\"),\n",
- "             gr.File(label=\"Score Analysis\"),\n",
- "             gr.File(label=\"Merged Data\"),\n",
- "    ],\n",
- "    examples=[\n",
- "        [\"ABCDXABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDA\", [\"sample_1.jpg\"]],\n",
- "        [\"ABCDXABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDA\", [\"sample_2.jpg\"]],\n",
- "        [\"ABCDXABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDABCDA\", [\"sample_1.jpg\", \"sample_2.jpg\"]],\n",
- "    ]\n",
- ")\n",
- "\n",
- "# Launch the interface\n",
- "interface.launch()"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
- }