rzimmerdev committed on
Commit
6ec3bf6
•
1 Parent(s): 590d01a

Finished MNIST downloading and caching modules

Browse files
notebooks/dataloader.ipynb DELETED
@@ -1,198 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "metadata": {
7
- "collapsed": true,
8
- "pycharm": {
9
- "name": "#%%\n"
10
- }
11
- },
12
- "outputs": [
13
- {
14
- "name": "stdout",
15
- "output_type": "stream",
16
- "text": [
17
- "/mnt/c/Users/rzimm/Workspace/data/zero-to-hero\n"
18
- ]
19
- }
20
- ],
21
- "source": [
22
- "%cd ..\n",
23
- "%load_ext autoreload\n",
24
- "%autoreload 2\n",
25
- "from datasets import downloader"
26
- ]
27
- },
28
- {
29
- "cell_type": "code",
30
- "execution_count": 56,
31
- "outputs": [],
32
- "source": [
33
- "import pandas as pd\n",
34
- "import numpy as np\n",
35
- "import random\n",
36
- "from glob import glob, escape\n",
37
- "import imageio.v2 as imageio"
38
- ],
39
- "metadata": {
40
- "collapsed": false,
41
- "pycharm": {
42
- "name": "#%%\n"
43
- }
44
- }
45
- },
46
- {
47
- "cell_type": "code",
48
- "execution_count": null,
49
- "outputs": [],
50
- "source": [
51
- "# download.download(\"cityscapes\", \"datasets/downloaded\")"
52
- ],
53
- "metadata": {
54
- "collapsed": false,
55
- "pycharm": {
56
- "name": "#%%\n",
57
- "is_executing": true
58
- }
59
- }
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": 41,
64
- "outputs": [],
65
- "source": [
66
- "def load_dataset(name=\"gtFine\", path=\"datasets/downloads/\"):\n",
67
- " src = path+name\n",
68
- " test, train, val = [f\"{src}/{subpath}\" for subpath in [\"test\", \"train\", \"val\"]]\n",
69
- "\n",
70
- " dataset = {\"test\": glob(test + \"/*/*\"), \"train\": glob(train + \"/*/*\"), \"val\": glob(val + \"/*/*\")}\n",
71
- "\n",
72
- " return dataset"
73
- ],
74
- "metadata": {
75
- "collapsed": false,
76
- "pycharm": {
77
- "name": "#%%\n"
78
- }
79
- }
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": 44,
84
- "outputs": [
85
- {
86
- "data": {
87
- "text/plain": "list"
88
- },
89
- "execution_count": 44,
90
- "metadata": {},
91
- "output_type": "execute_result"
92
- }
93
- ],
94
- "source": [
95
- "type(load_dataset()[\"train\"])"
96
- ],
97
- "metadata": {
98
- "collapsed": false,
99
- "pycharm": {
100
- "name": "#%%\n"
101
- }
102
- }
103
- },
104
- {
105
- "cell_type": "code",
106
- "execution_count": 45,
107
- "outputs": [],
108
- "source": [
109
- "a = [1, 2, 3]"
110
- ],
111
- "metadata": {
112
- "collapsed": false,
113
- "pycharm": {
114
- "name": "#%%\n"
115
- }
116
- }
117
- },
118
- {
119
- "cell_type": "code",
120
- "execution_count": 143,
121
- "outputs": [],
122
- "source": [
123
- "class DataLoader:\n",
124
- " def __init__(self, data):\n",
125
- " self.data = np.array(data)\n",
126
- " self.total = len(self.data)\n",
127
- " self.__items = self.data\n",
128
- " self.__remaining = len(self.data)\n",
129
- " def __next__(self, n=1):\n",
130
- " if n > self.total:\n",
131
- " raise ValueError(f\"Dataset doesn't have enough elements to suffice request of {n} elements.\")\n",
132
- " if self.__remaining > 0:\n",
133
- " indices = random.sample(range(self.__remaining), n)\n",
134
- " sampled = self.__items[indices]\n",
135
- " self.__items = np.delete(self.__items, indices)\n",
136
- " self.__remaining -= n\n",
137
- " return sampled\n",
138
- " else:\n",
139
- " self.__items = self.data\n",
140
- " self.__remaining = len(self.data)\n",
141
- " return self.__next__(n)"
142
- ],
143
- "metadata": {
144
- "collapsed": false,
145
- "pycharm": {
146
- "name": "#%%\n"
147
- }
148
- }
149
- },
150
- {
151
- "cell_type": "code",
152
- "execution_count": 144,
153
- "outputs": [],
154
- "source": [
155
- "loader = DataLoader(a)"
156
- ],
157
- "metadata": {
158
- "collapsed": false,
159
- "pycharm": {
160
- "name": "#%%\n"
161
- }
162
- }
163
- },
164
- {
165
- "cell_type": "code",
166
- "execution_count": null,
167
- "outputs": [],
168
- "source": [],
169
- "metadata": {
170
- "collapsed": false,
171
- "pycharm": {
172
- "name": "#%%\n"
173
- }
174
- }
175
- }
176
- ],
177
- "metadata": {
178
- "kernelspec": {
179
- "display_name": "Python 3",
180
- "language": "python",
181
- "name": "python3"
182
- },
183
- "language_info": {
184
- "codemirror_mode": {
185
- "name": "ipython",
186
- "version": 2
187
- },
188
- "file_extension": ".py",
189
- "mimetype": "text/x-python",
190
- "name": "python",
191
- "nbconvert_exporter": "python",
192
- "pygments_lexer": "ipython2",
193
- "version": "2.7.6"
194
- }
195
- },
196
- "nbformat": 4,
197
- "nbformat_minor": 0
198
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/dataset.ipynb ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "outputs": [],
7
+ "source": [
8
+ "import os\n",
9
+ "import gzip"
10
+ ],
11
+ "metadata": {
12
+ "collapsed": false,
13
+ "pycharm": {
14
+ "name": "#%%\n"
15
+ }
16
+ }
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 3,
21
+ "outputs": [
22
+ {
23
+ "name": "stderr",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "/home/rzimmerdev/conda/envs/data/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
27
+ " warnings.warn(\"Setuptools is replacing distutils.\")\n"
28
+ ]
29
+ }
30
+ ],
31
+ "source": [
32
+ "from src.downloader import download_dataset\n",
33
+ "download_dataset(\"mnist\", \"../datasets/mnist\")"
34
+ ],
35
+ "metadata": {
36
+ "collapsed": false,
37
+ "pycharm": {
38
+ "name": "#%%\n"
39
+ }
40
+ }
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 4,
45
+ "outputs": [
46
+ {
47
+ "data": {
48
+ "text/plain": "b'\\x00\\x00\\x08\\x03\\x00\\x00\\xea`\\x00\\x00\\x00\\x1c\\x00\\x00\\x00\\x1c'"
49
+ },
50
+ "execution_count": 4,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ }
54
+ ],
55
+ "source": [
56
+ "f = gzip.open(\"../datasets/mnist/\" + os.listdir(\"../datasets/mnist/\")[0], 'r')\n",
57
+ "f.read(16)"
58
+ ],
59
+ "metadata": {
60
+ "collapsed": false,
61
+ "pycharm": {
62
+ "name": "#%%\n"
63
+ }
64
+ }
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 5,
69
+ "outputs": [],
70
+ "source": [
71
+ "import numpy as np\n",
72
+ "from torch.utils.data import DataLoader, Dataset"
73
+ ],
74
+ "metadata": {
75
+ "collapsed": false,
76
+ "pycharm": {
77
+ "name": "#%%\n"
78
+ }
79
+ }
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 6,
84
+ "outputs": [],
85
+ "source": [
86
+ "class DatasetMNIST(Dataset):\n",
87
+ " def __init__(self, images, labels):\n",
88
+ " with gzip.open(images, 'r') as f:\n",
89
+ " f.read(4)\n",
90
+ " self.total = int.from_bytes(f.read(4), 'big')\n",
91
+ " rows = int.from_bytes(f.read(4), 'big')\n",
92
+ " columns = int.from_bytes(f.read(4), 'big')\n",
93
+ "\n",
94
+ " image_data = f.read()\n",
95
+ " images = np.frombuffer(image_data, dtype=np.uint8)\\\n",
96
+ " .reshape((self.total, rows, columns))\n",
97
+ " self.images = images\n",
98
+ " with gzip.open(labels, 'r') as f:\n",
99
+ " f.read(4)\n",
100
+ " total = int.from_bytes(f.read(4), 'big')\n",
101
+ "\n",
102
+ " label_data = f.read()\n",
103
+ " labels = np.frombuffer(label_data, dtype=np.uint8)\n",
104
+ " self.labels = labels\n",
105
+ " self.data = list(zip(self.images, self.labels))\n",
106
+ " def __getitem__(self, n):\n",
107
+ " if n > self.total:\n",
108
+ " raise ValueError(f\"Dataset doesn't have enough elements to suffice request of {n} elements.\")\n",
109
+ " return self.data[n]\n",
110
+ "\n",
111
+ " def __len__(self):\n",
112
+ " return len(self.data)"
113
+ ],
114
+ "metadata": {
115
+ "collapsed": false,
116
+ "pycharm": {
117
+ "name": "#%%\n"
118
+ }
119
+ }
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 7,
124
+ "outputs": [],
125
+ "source": [
126
+ "dataset_dir = \"../datasets/mnist/\"\n",
127
+ "loader = DatasetMNIST(dataset_dir + \"train_images\", dataset_dir + \"train_labels\")"
128
+ ],
129
+ "metadata": {
130
+ "collapsed": false,
131
+ "pycharm": {
132
+ "name": "#%%\n"
133
+ }
134
+ }
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 8,
139
+ "outputs": [
140
+ {
141
+ "data": {
142
+ "text/plain": "<Figure size 432x288 with 1 Axes>",
143
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASY0lEQVR4nO3de9BU9X3H8ffHu4IGKZcgigZk2qi1JjKaTqjQ8VpnVNR6a42AmWBMYpNMbLXUqhm1STrV1GljJogWFJRg1HiJrTI2iqQpkaSoCBjUYkQIqEhUojHAt3+c83ROHnfPLnuX3+c1s/Ps7nfPOV+W/ew5e86e/SkiMLMd307dbsDMOsNhN0uEw26WCIfdLBEOu1kiHHazRDjstl0kXS1pTpXaRElr6pzPFEmLGuyh4WlT5rC3kKTHJL0hafcOLe8gSSFpl+2YZrWk49rZV6+RdIqkZZLelvRfkg7pdk/d4LC3iKSDgD8BAji1u91YH0ljgbnAZ4FBwAPA/dvzBrmjcNhb5wLgv4FZwORiQdIsSd+S9ANJb0laLGlMoR6SPitpVb5l8C1Jyms7SbpC0kuSNki6TdKH8kkX5n835WutP5Y0RtJ/Snpd0muS5koalM/rdmAU8ED++L/J7/9EvsbbJOkpSRMLvX1E0uN53wuAIfU+IZIul/RCPu1ySae//yH6F0m/krRS0rGFwock3SJpnaRXJF0raed6l11wIvBERCyKiC3AN4CRwIQG5vXBFhG+tOACPA98DjgS+C0wvFCbBWwEjgJ2IVvTzCvUA3iQbM0zCngVOCmvXZjPezQwELgHuD2vHZRPu0thXgcDxwO7A0PJ3hD+uVBfDRxXuD0SeB04mezN//j89tC8/mPghnx+xwBvAXOqPAcTgTWF22cB++XzPQfYDIzIa1OALcCXgV3z+q+AwXn9+8B3gAHAMOAnwEWFaRcVlvMgcHmVni4BHirc3hl4F/hit18zHX+NdruBHeECjM8DPiS/vRL4cqE+C5hZuH0ysLJwO4Dxhdvz+168wKPA5wq138+XtUulsFfobRLwP4Xb/cN+Wd+bR+G+h8m2TkblgRxQqN1Rb9gr1JcCp+XXpwBrARXqPwE+BQwHfgPsWaidB/ywMO2iasvpt8w/yN9kJgK7AX8PbAP+ttuvm05fvBnfGpOBRyLitfz2HfTblAd+Wbj+a7K1dD31/YCXCrWXyII+vFIjkoZJmpdv+r4JzKF80/tA4Kx8E36TpE1kb14j8mW/ERGb+y2/LpIukLS0MN/D+vXySuSJLMx7v7ynXYF1hWm/Q7aG3y4RsZLs/+JfgXX58pcDdR012JEkt5Oi1STtCZwN7CypL7C7A4Mk/VFEPNXkItaSvfj79K1t15Ntgvf3NbK1/eER8bqkSWQv9D79T3N8mWzN/pn+M5J0ILCvpAGFwI+qMI/3yae9GTgW+HFEbJW0FFDhYSMlqRD4UcD9eU+/IdtS2lJrWbVExPeA7+V9DSL7aPRks/P9oPGavXmTgK3AIcAR+eWjwBNkO+2adSfw5XxH2UDgH4Dv5iF4lWyTdHTh8XsDb5PttBsJ/HW/+a3v9/g5wCmSTpS0s6Q98uPl+0fES8AS4KuSdpM0Hjilzr4HkL0pvAogaSrZmr1oGPBXknaVdBbZ8/ZQRKwDHgGul7RPvpNyjKSGdqpJOjL/tw0l20J4IF/jJ8Vhb95k4N8i4hcR8cu+C9na9C9bcIjnVuB2sh1t/0u2c+kSgIj4NXAd8KN8c/cTwFeBj5Pt7PoB2Q69oq8BV+SPvzQiXgZOA6aTBfNlsjeIvtfGXwBHk+1gvAq4rZ6mI2I5cD3ZDr71wB8CP+r3sMXAWOC1/N/x5xHxel67gOwz9nLgDbI184hKy5L075Kml7RzI7AJeC7/+76tmBTodz8ymdmOymt2s0Q47GaJcNjNEuGwmyWio8fZJXlvoFmbRYQq3d/Uml3SSZKek/S8pMubmZeZtVfD
h97yM5B+TnbixBqybySdlx9frTaN1+xmbdaONftRwPMR8WJEvAfMI/tyhpn1oGbCPpLs21Z91lDhu9qSpklaImlJE8sysyY1s4Ou0qbC+zbTI2IGMAO8GW/WTc2s2dcABxRu7092hpaZ9aBmwv4kMDY/G2s34Fyy0xPNrAc1vBkfEVskfYHsV012Bm6NiGdb1pmZtVRHz3rzZ3az9mvLl2rM7IPDYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q0PD47gKTVwFvAVmBLRIxrRVNm1npNhT33pxHxWgvmY2Zt5M14s0Q0G/YAHpH0U0nTKj1A0jRJSyQtaXJZZtYERUTjE0v7RcRaScOABcAlEbGw5PGNL8zM6hIRqnR/U2v2iFib/90A3Asc1cz8zKx9Gg67pAGS9u67DpwALGtVY2bWWs3sjR8O3Cupbz53RMR/tKQrM2u5pj6zb/fC/JndrO3a8pndzD44HHazRDjsZolw2M0S4bCbJaIVJ8JYDzv66KNL6+eff35pfcKECaX1Qw89dLt76nPppZeW1teuXVtaHz9+fGl9zpw5VWuLFy8unXZH5DW7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIn/W2AzjnnHOq1m688cbSaYcMGVJaz09hruqxxx4rrQ8dOrRq7ZBDDimdtpZavd11111Va+eee25Ty+5lPuvNLHEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEz2fvAbvsUv7fMG5c+eC4N998c9XaXnvtVTrtwoVVB/AB4JprrimtL1q0qLS+++67V63Nnz+/dNoTTjihtF7LkiUecazIa3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE+zt4Dav12+8yZMxue94IFC0rrZefCA7z55psNL7vW/Js9jr5mzZrS+uzZs5ua/46m5ppd0q2SNkhaVrhvsKQFklblf/dtb5tm1qx6NuNnASf1u+9y4NGIGAs8mt82sx5WM+wRsRDY2O/u04C+baTZwKTWtmVmrdboZ/bhEbEOICLWSRpW7YGSpgHTGlyOmbVI23fQRcQMYAb4ByfNuqnRQ2/rJY0AyP9uaF1LZtYOjYb9fmByfn0ycF9r2jGzdqn5u/GS7gQmAkOA9cBVwPeB+cAo4BfAWRHRfydepXkluRlf65zw6dOnl9Zr/R/ddNNNVWtXXHFF6bTNHkevZcWKFVVrY8eObWreZ555Zmn9vvvSXAdV+934mp/ZI+K8KqVjm+rIzDrKX5c1S4TDbpYIh90sEQ67WSIcdrNE+BTXFrjyyitL67UOrb333nul9Ycffri0ftlll1WtvfPOO6XT1rLHHnuU1mudpjpq1KiqtVpDLl977bWl9VQPrTXKa3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE1T3Ft6cI+wKe4Dho0qGpt5cqVpdMOGTKktP7ggw+W1idNmlRab8bBBx9cWp87d25p/cgjj2x42XfffXdp/cILLyytb968ueFl78iqneLqNbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulggfZ6/TsGFVR7hi7dq1Tc179OjRpfV33323tD516tSqtVNPPbV02sMOO6y0PnDgwNJ6rddPWf2MM84onfaBBx4orVtlPs5uljiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCx9nrVHY+e9mwxABDhw4trdf6/fR2/h/V+o5Ard5GjBhRWn/11VcbntYa0/Bxdkm3StogaVnhvqslvSJpaX45uZXNmlnr1bMZPws4qcL934yII/LLQ61ty8xarWbYI2IhsLEDvZhZGzWzg+4Lkp7ON/P3rfYgSdMkLZG0pIllmVmTGg37t4ExwBHA
OuD6ag+MiBkRMS4ixjW4LDNrgYbCHhHrI2JrRGwDbgaOam1bZtZqDYVdUvGYyenAsmqPNbPeUHN8dkl3AhOBIZLWAFcBEyUdAQSwGriofS32hk2bNlWt1fpd91q/Cz948ODS+gsvvFBaLxunfNasWaXTbtxYvu913rx5pfVax8prTW+dUzPsEXFehbtvaUMvZtZG/rqsWSIcdrNEOOxmiXDYzRLhsJsloubeeKtt8eLFpfVap7h20zHHHFNanzBhQml927ZtpfUXX3xxu3uy9vCa3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhI+zJ27PPfcsrdc6jl7rZ659imvv8JrdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEh2y2Ulu3bi2t13r9lP3UdNlwzta4hodsNrMdg8NulgiH3SwRDrtZIhx2s0Q47GaJcNjNElHPkM0HALcBHwa2ATMi4kZJg4HvAgeRDdt8dkS80b5WrR1OPPHEbrdgHVLPmn0L8JWI+CjwCeDzkg4BLgcejYixwKP5bTPrUTXDHhHrIuJn+fW3gBXASOA0YHb+sNnApDb1aGYtsF2f2SUdBHwMWAwMj4h1kL0hAMNa3p2ZtUzdv0EnaSBwN/CliHhTqvj120rTTQOmNdaembVKXWt2SbuSBX1uRNyT371e0oi8PgLYUGnaiJgREeMiYlwrGjazxtQMu7JV+C3Aioi4oVC6H5icX58M3Nf69sysVerZjP8k8CngGUlL8/umA18H5kv6NPAL4Ky2dGhtNXr06G63YB1SM+wRsQio9gH92Na2Y2bt4m/QmSXCYTdLhMNulgiH3SwRDrtZIhx2s0R4yObEPfHEE6X1nXYqXx/UGtLZeofX7GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZInycPXHLli0rra9ataq0Xut8+DFjxlStecjmzvKa3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhCKicwuTOrcwa4kpU6aU1mfOnFlaf/zxx6vWLrnkktJply9fXlq3yiKi4k+/e81ulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyWi5nF2SQcAtwEfBrYBMyLiRklXA58B+k5Knh4RD9WYl4+zf8Dss88+pfX58+eX1o877riqtXvuuad02qlTp5bWN2/eXFpPVbXj7PX8eMUW4CsR8TNJewM/lbQgr30zIv6pVU2aWfvUDHtErAPW5dffkrQCGNnuxsystbbrM7ukg4CPAYvzu74g6WlJt0rat8o00yQtkbSkuVbNrBl1h13SQOBu4EsR8SbwbWAMcATZmv/6StNFxIyIGBcR45pv18waVVfYJe1KFvS5EXEPQESsj4itEbENuBk4qn1tmlmzaoZdkoBbgBURcUPh/hGFh50OlP9MqZl1VT2H3sYDTwDPkB16A5gOnEe2CR/AauCifGde2bx86G0HU+vQ3HXXXVe1dvHFF5dOe/jhh5fWfQpsZQ0feouIRUCliUuPqZtZb/E36MwS4bCbJcJhN0uEw26WCIfdLBEOu1ki/FPSZjsY/5S0WeIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpaIen5dtpVeA14q3B6S39eLerW3Xu0L3FujWtnbgdUKHf1SzfsWLi3p1d+m69XeerUvcG+N6lRv3ow3S4TDbpaIbod9RpeXX6ZXe+vVvsC9NaojvXX1M7uZdU631+xm1iEOu1kiuhJ2SSdJek7S85Iu70YP1UhaLekZSUu7PT5dPobeBknLCvcNlrRA0qr8b8Ux9rrU29WSXsmfu6WSTu5SbwdI+qGkFZKelfTF/P6uPnclfXXkeev4Z3ZJOwM/B44H1gBPAudFRE/84r+k1cC4iOj6FzAkHQO8DdwWEYfl9/0jsDEivp6/Ue4bEZf1SG9XA293exjvfLSiEcVhxoFJwBS6+NyV9HU2HXjeurFmPwp4PiJejIj3gHnAaV3oo+dFxEJgY7+7TwNm59dn
k71YOq5Kbz0hItZFxM/y628BfcOMd/W5K+mrI7oR9pHAy4Xba+it8d4DeETSTyVN63YzFQzvG2Yr/zusy/30V3MY707qN8x4zzx3jQx/3qxuhL3S72P10vG/T0bEx4E/Az6fb65afeoaxrtTKgwz3hMaHf68Wd0I+xrggMLt/YG1XeijoohYm//dANxL7w1Fvb5vBN3874Yu9/P/emkY70rDjNMDz103hz/vRtifBMZK+oik3YBzgfu70Mf7SBqQ7zhB0gDgBHpvKOr7gcn59cnAfV3s5Xf0yjDe1YYZp8vPXdeHP4+Ijl+Ak8n2yL8A/F03eqjS12jgqfzybLd7A+4k26z7LdkW0aeB3wMeBVblfwf3UG+3kw3t/TRZsEZ0qbfxZB8NnwaW5peTu/3clfTVkefNX5c1S4S/QWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJeL/AHyD7vpJDzRWAAAAAElFTkSuQmCC\n"
144
+ },
145
+ "metadata": {
146
+ "needs_background": "light"
147
+ },
148
+ "output_type": "display_data"
149
+ }
150
+ ],
151
+ "source": [
152
+ "import matplotlib.pyplot as plt\n",
153
+ "X, y = loader[4]\n",
154
+ "plt.imshow(X, cmap=\"gray\")\n",
155
+ "plt.title(label=\"Annotated label: \" + str(y))\n",
156
+ "plt.show()"
157
+ ],
158
+ "metadata": {
159
+ "collapsed": false,
160
+ "pycharm": {
161
+ "name": "#%%\n"
162
+ }
163
+ }
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "kernelspec": {
168
+ "display_name": "Python 3",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "codemirror_mode": {
174
+ "name": "ipython",
175
+ "version": 2
176
+ },
177
+ "file_extension": ".py",
178
+ "mimetype": "text/x-python",
179
+ "name": "python",
180
+ "nbconvert_exporter": "python",
181
+ "pygments_lexer": "ipython2",
182
+ "version": "2.7.6"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 0
187
+ }
src/dataset.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import gzip
4
+
5
+ from src.downloader import download_dataset
6
+
7
+ import numpy as np
8
+ from torch.utils.data import Dataset
9
+
10
+
11
def load_mnist(download_dir):
    """Download MNIST into ``download_dir`` and return the cached file paths.

    Args:
        download_dir: Directory the MNIST files are stored in. A trailing
            path separator is optional.

    Returns:
        A dict with the keys ``"train"`` and ``"test"``; each value is an
        ``(images_path, labels_path)`` tuple.
    """
    # Local import: this module does not import ``os`` at file level.
    import os

    download_dataset("mnist", download_dir)

    # os.path.join instead of ``download_dir + name`` so that a directory
    # passed without a trailing slash does not yield "dirtrain_images".
    return {
        split: (os.path.join(download_dir, f"{split}_images"),
                os.path.join(download_dir, f"{split}_labels"))
        for split in ("train", "test")
    }
16
+
17
+
18
class _DatasetIndexError(IndexError, ValueError):
    """Raised for an out-of-range dataset index.

    Derives from IndexError so that sequence-style iteration
    (``for sample in dataset``) terminates correctly, and from ValueError so
    that callers catching the ValueError raised by earlier versions keep
    working.
    """


class DatasetMNIST(Dataset):
    """Map-style dataset over a pair of gzip-compressed MNIST files.

    Args:
        images: Path to the gzip-compressed image file. Its header is read as
            a 4-byte field that is skipped, followed by three big-endian
            4-byte integers: item count, rows and columns.
        labels: Path to the gzip-compressed label file. Its first 8 bytes are
            skipped; the remainder is read as one unsigned byte per label.

    Attributes are documented inline in ``__init__``.

    Raises:
        ValueError: If the number of labels differs from the number of images.
    """

    def __init__(self, images, labels):
        with gzip.open(images, 'rb') as f:
            f.read(4)  # skip the leading 4-byte header field
            # total: number of images announced by the image file header
            self.total = int.from_bytes(f.read(4), 'big')
            rows = int.from_bytes(f.read(4), 'big')
            columns = int.from_bytes(f.read(4), 'big')

            image_data = f.read()
        # images: uint8 array of shape (total, rows, columns)
        self.images = np.frombuffer(image_data, dtype=np.uint8).reshape((self.total, rows, columns))

        with gzip.open(labels, 'rb') as f:
            f.read(8)  # skip the 8-byte label file header

            label_data = f.read()
        # labels: uint8 array with one entry per image
        self.labels = np.frombuffer(label_data, dtype=np.uint8)

        # zip() would silently drop the surplus of the longer array, hiding a
        # mismatched or truncated file pair; fail loudly instead.
        if len(self.labels) != self.total:
            raise ValueError(
                f"Image file holds {self.total} items but label file holds {len(self.labels)}."
            )

        # data: list of (image, label) pairs, indexable by sample number
        self.data = list(zip(self.images, self.labels))

    def __getitem__(self, n):
        """Return the ``(image, label)`` pair at index ``n``.

        Raises:
            IndexError: If ``n`` is outside the dataset. The raised error is
                also a ValueError for backward compatibility.
        """
        # ``>=`` rather than ``>``: the last valid index is len - 1.
        if n >= len(self.data):
            raise _DatasetIndexError(
                f"Index {n} is out of range for a dataset of {len(self.data)} elements."
            )
        return self.data[n]

    def __len__(self):
        """Return the number of ``(image, label)`` pairs."""
        return len(self.data)
44
+
45
+
46
if __name__ == "__main__":
    # Manual smoke test: fetch MNIST, load the training split and show the
    # fifth sample together with its label.
    download_dir = "../downloads/mnist/"
    mnist = load_mnist(download_dir)

    train_images, train_labels = mnist["train"]
    dataset = DatasetMNIST(train_images, train_labels)

    import matplotlib.pyplot as plt

    sample = dataset[4]
    X, y = sample
    title = "Annotated label: " + str(y)
    plt.imshow(X, cmap="gray")
    plt.title(label=title)
    plt.show()
{datasets → src}/downloader.py RENAMED
@@ -4,26 +4,38 @@
4
  # To learn more about the dataset, access:
5
  # https://www.cityscapes-dataset.com/
6
  import os
7
- import sys
8
  import pip
 
9
 
10
 
11
- # Download and cache dataset
12
- def main():
13
- pass
14
-
15
-
16
- def download(name='cityscapes', path='datasets/downloads'):
17
- """Select one of the available and implemented datasets to download:
18
  name=any(['cityscapes', 'camvid', 'labelme'])
19
  """
20
  if name == 'cityscapes':
21
  download_cityscapes(path)
 
 
22
  else:
23
  raise NotImplementedError
24
 
25
 
26
- def download_cityscapes(path='datasets/downloads'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if hasattr(pip, 'main'):
28
  pip.main(['install', 'cityscapesscripts'])
29
  else:
@@ -36,7 +48,3 @@ def download_cityscapes(path='datasets/downloads'):
36
  print("Invalid dataset name. Please try again.")
37
  ds_name = input()
38
  os.system(f"csDownload {ds_name} -d {path}/{ds_name}")
39
-
40
-
41
- if __name__ == "__main__":
42
- main()
 
4
  # To learn more about the dataset, access:
5
  # https://www.cityscapes-dataset.com/
6
  import os
 
7
  import pip
8
+ from urllib.request import urlretrieve
9
 
10
 
11
def download_dataset(name='cityscapes', path='downloads/downloads'):
    """Download one of the implemented datasets into ``path``.

    Args:
        name: Dataset identifier. ``'cityscapes'`` and ``'mnist'`` are
            implemented.
        path: Directory handed to the dataset-specific downloader.

    Raises:
        NotImplementedError: If ``name`` is not an implemented dataset.
    """
    if name == 'cityscapes':
        download_cityscapes(path)
    elif name == "mnist":
        # This branch used to be a bare ``pass``, so requesting MNIST
        # downloaded nothing; dispatch to the actual downloader.
        download_mnist(path)
    else:
        raise NotImplementedError(f"Dataset {name!r} is not implemented.")
21
 
22
 
23
def download_mnist(path="downloads/mnist"):
    """Download the four MNIST archives into ``path``, skipping cached files.

    Files are stored under the names ``train_images``, ``train_labels``,
    ``test_images`` and ``test_labels``. A file that already exists in
    ``path`` is not downloaded again.

    Args:
        path: Target directory; created if it does not exist.
    """
    remote_files = {"train_images": "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                    "train_labels": "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
                    "test_images": "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                    "test_labels": "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)

    for file, url in remote_files.items():
        target = os.path.join(path, file)
        if os.path.exists(target):
            continue

        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated file that the
        # existence check above would later mistake for a valid cache entry.
        partial = target + ".part"
        try:
            urlretrieve(url, partial)
            os.replace(partial, target)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
36
+
37
+
38
+ def download_cityscapes(path='downloads/cityscapes'):
39
  if hasattr(pip, 'main'):
40
  pip.main(['install', 'cityscapesscripts'])
41
  else:
 
48
  print("Invalid dataset name. Please try again.")
49
  ds_name = input()
50
  os.system(f"csDownload {ds_name} -d {path}/{ds_name}")