Spaces:
Running
Running
Commit ·
1f2f6bf
1
Parent(s): 5e54979
Add refactored version of cnn visualizer
Browse files- README.md +73 -9
- dist/assets/index-BoCAEVud.css +0 -1
- dist/assets/index-CQRg13xj.css +1 -0
- dist/assets/{index-CARha6nB.js → index-p2vRWBSG.js} +0 -0
- dist/index.html +3 -3
- eslint.config.js +3 -9
- index.html +2 -2
- package-lock.json +0 -0
- package.json +10 -3
- src/App.tsx +27 -0
- src/ConvolutionVisualizer.tsx +406 -0
- src/InfoViewer.tsx +434 -0
- src/NetworkVisualizer.tsx +452 -0
- src/datasets.ts +0 -0
- src/index.css +1 -66
- src/kernels.ts +193 -0
- src/main.tsx +10 -0
- src/mnist.d.ts +40 -0
- src/mnist.js +195 -0
- src/train.ts +437 -0
- src/types.d.ts +1 -0
- src/ui/Button.tsx +15 -0
- src/ui/Dropdown.tsx +25 -0
- src/ui/InputField.tsx +46 -0
- src/ui/LoadingScreen.tsx +19 -0
- src/ui/Radio.tsx +27 -0
- src/ui/Tabs.tsx +25 -0
- src/useConvolutionProcessing.ts +198 -0
- tsconfig.app.json +28 -0
- tsconfig.json +11 -0
- tsconfig.node.json +26 -0
- vite.config.js → vite.config.ts +5 -1
README.md
CHANGED
|
@@ -1,9 +1,73 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# React + TypeScript + Vite
|
| 2 |
+
|
| 3 |
+
This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules.
|
| 4 |
+
|
| 5 |
+
Currently, two official plugins are available:
|
| 6 |
+
|
| 7 |
+
- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react) uses [Babel](https://babeljs.io/) (or [oxc](https://oxc.rs) when used in [rolldown-vite](https://vite.dev/guide/rolldown)) for Fast Refresh
|
| 8 |
+
- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh
|
| 9 |
+
|
| 10 |
+
## React Compiler
|
| 11 |
+
|
| 12 |
+
The React Compiler is not enabled on this template because of its impact on dev & build performances. To add it, see [this documentation](https://react.dev/learn/react-compiler/installation).
|
| 13 |
+
|
| 14 |
+
## Expanding the ESLint configuration
|
| 15 |
+
|
| 16 |
+
If you are developing a production application, we recommend updating the configuration to enable type-aware lint rules:
|
| 17 |
+
|
| 18 |
+
```js
|
| 19 |
+
export default defineConfig([
|
| 20 |
+
globalIgnores(['dist']),
|
| 21 |
+
{
|
| 22 |
+
files: ['**/*.{ts,tsx}'],
|
| 23 |
+
extends: [
|
| 24 |
+
// Other configs...
|
| 25 |
+
|
| 26 |
+
// Remove tseslint.configs.recommended and replace with this
|
| 27 |
+
tseslint.configs.recommendedTypeChecked,
|
| 28 |
+
// Alternatively, use this for stricter rules
|
| 29 |
+
tseslint.configs.strictTypeChecked,
|
| 30 |
+
// Optionally, add this for stylistic rules
|
| 31 |
+
tseslint.configs.stylisticTypeChecked,
|
| 32 |
+
|
| 33 |
+
// Other configs...
|
| 34 |
+
],
|
| 35 |
+
languageOptions: {
|
| 36 |
+
parserOptions: {
|
| 37 |
+
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
| 38 |
+
tsconfigRootDir: import.meta.dirname,
|
| 39 |
+
},
|
| 40 |
+
// other options...
|
| 41 |
+
},
|
| 42 |
+
},
|
| 43 |
+
])
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
You can also install [eslint-plugin-react-x](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-x) and [eslint-plugin-react-dom](https://github.com/Rel1cx/eslint-react/tree/main/packages/plugins/eslint-plugin-react-dom) for React-specific lint rules:
|
| 47 |
+
|
| 48 |
+
```js
|
| 49 |
+
// eslint.config.js
|
| 50 |
+
import reactX from 'eslint-plugin-react-x'
|
| 51 |
+
import reactDom from 'eslint-plugin-react-dom'
|
| 52 |
+
|
| 53 |
+
export default defineConfig([
|
| 54 |
+
globalIgnores(['dist']),
|
| 55 |
+
{
|
| 56 |
+
files: ['**/*.{ts,tsx}'],
|
| 57 |
+
extends: [
|
| 58 |
+
// Other configs...
|
| 59 |
+
// Enable lint rules for React
|
| 60 |
+
reactX.configs['recommended-typescript'],
|
| 61 |
+
// Enable lint rules for React DOM
|
| 62 |
+
reactDom.configs.recommended,
|
| 63 |
+
],
|
| 64 |
+
languageOptions: {
|
| 65 |
+
parserOptions: {
|
| 66 |
+
project: ['./tsconfig.node.json', './tsconfig.app.json'],
|
| 67 |
+
tsconfigRootDir: import.meta.dirname,
|
| 68 |
+
},
|
| 69 |
+
// other options...
|
| 70 |
+
},
|
| 71 |
+
},
|
| 72 |
+
])
|
| 73 |
+
```
|
dist/assets/index-BoCAEVud.css
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
:root{font-family:system-ui,Avenir,Helvetica,Arial,sans-serif;line-height:1.5;font-weight:400;color-scheme:light dark;color:#ffffffde;background-color:#242424;font-synthesis:none;text-rendering:optimizeLegibility;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}a{font-weight:500;color:#646cff;text-decoration:inherit}a:hover{color:#535bf2}body{margin:0;min-width:320px;min-height:100vh}h1{font-size:3.2em;line-height:1.1}button{border-radius:8px;border:1px solid transparent;padding:.6em 1.2em;font-size:1em;font-weight:500;font-family:inherit;background-color:#1a1a1a;cursor:pointer;transition:border-color .25s}button:hover{border-color:orange}button:focus,button:focus-visible{outline:4px auto -webkit-focus-ring-color}@media(prefers-color-scheme:light){:root{color:#213547;background-color:#fff}a:hover{color:#747bff}button{background-color:#f9f9f9}}h1{font-size:40px;font-weight:700;text-align:left}.main-container{flex:1;display:grid;grid-template-columns:2fr 1fr;gap:2rem}.convolution-viewer{display:flex;flex-direction:row;justify-content:center;gap:2rem;padding:2rem}.convolution-viewer-item{display:flex;flex-direction:column;align-items:center;width:300px;height:340px;border-radius:10px;background:#fafafa;padding:1rem;box-shadow:0 2px 6px #0000000d;transition:transform .2s ease,box-shadow .2s ease;border:1px solid black}.convolution-viewer-item h3{margin:.5rem 0 1rem;font-size:1.1rem;font-weight:600;color:#333;text-align:center;border-bottom:2px solid orange;padding-bottom:.3rem}.convolution-viewer-item img{flex-grow:1;width:100%;height:100%;max-height:none;object-fit:contain}` .app-button{padding:8px 16px;background-color:#f0f0f0;border:none;border-radius:8px;cursor:pointer;font-size:1rem}.app-button:focus{outline:none}.app-button:hover{outline:1px solid orange}.app-button.selected{background-color:orange;color:#fff}.preset-selector{display:flex;align-items:center;gap:12px;background:#f7f7f7;padding:12px 16px;border-radius:8px;border:1px solid #ddd;margin-bottom:12px;width:fit-content}.preset-label{font-weight:600;font-size:14px}.preset-dropdown{padding:6px 10px;font-size:14px;border-radius:6px;border:1px solid #ccc;background:#fff}.preset-dropdown:focus{outline:none}.preset-dropdown:focus-visible{outline:2px solid #4e9fff}.preset-apply{padding:6px 14px;font-size:14px;background:#4e9fff;color:#fff;border:none;border-radius:6px;cursor:pointer;transition:background .2s}.preset-apply:hover{background:#1e7be6}.kernel-editor{display:flex;flex-direction:column;gap:30px;padding:12px;max-width:fit-content;align-items:center}.kernel-matrix{display:flex;flex-direction:column;gap:8px}.kernel-row{display:flex;gap:8px}.kernel-size-options{border:1px solid #d0d0d0;border-radius:6px;padding:8px 12px;background:#f9f9f9;box-shadow:0 1px 2px #0000000a}.kernel-size-options form{display:flex;justify-content:space-between;align-items:center}.kernel-size-options input[type=number]{width:auto;max-width:4ch;margin:0 8px}.kernel-cell{width:4.5em;height:2.4em;text-align:center;font-size:15px;border:1px solid #d0d0d0;border-radius:6px;background:#fff;transition:border-color .2s,box-shadow .2s,transform .1s;box-shadow:0 1px 2px #0000000a}.kernel-cell:hover{border-color:#888}.kernel-cell:focus{outline:none;border-color:orange;box-shadow:0 0 0 3px #0078ff33;transform:scale(1.02)}.channel-buttons{display:flex;gap:12px}.channel-buttons .selected{background-color:orange;color:#fff}.channel-buttons button:focus{outline:none}.channel-buttons button:focus-visible{outline:2px solid orange}input[type=number]::-webkit-inner-spin-button,input[type=number]::-webkit-outer-spin-button{-webkit-appearance:none;margin:0}input[type=number]{-moz-appearance:textfield}.tabs{display:flex;flex-direction:column;height:100%}.tab-buttons{display:flex;border-bottom:1px solid #ccc;margin-bottom:20px}.tab-buttons button{padding:10px 20px;border:none;background:none;cursor:pointer;border-radius:0;will-change:border-bottom}.tab-buttons button.active{border-bottom:3px solid orange}.tab-content{flex:1;padding:20px;border:none}.tab-buttons button:focus{outline:none}.tab-buttons button:focus-visible{outline:2px solid orange}.switch-container{display:flex;align-items:center;justify-content:center;gap:12px;font-size:16px;-webkit-user-select:none;user-select:none}.switch{position:relative;width:70px;height:36px}.slider{position:absolute;cursor:pointer;inset:0;background:#777;border-radius:36px;transition:background .25s}.slider:before{position:absolute;content:"";height:28px;width:28px;left:4px;top:4px;background:#fff;border-radius:50%;transition:transform .25s}.switch input:checked+.slider{background:orange}.input-selector{display:flex;flex-direction:column;gap:30px}.switch{position:relative;display:inline-block;width:50px;height:24px}.switch input{opacity:0;width:0;height:0}.slider{position:absolute;cursor:pointer;inset:0;background-color:#ccc;transition:.4s;border-radius:24px}.slider:before{position:absolute;content:"";height:18px;width:18px;left:3px;bottom:3px;background-color:#fff;transition:.4s;border-radius:50%}input:checked+.slider{background-color:#4caf50}input:checked+.slider:before{transform:translate(26px)}.layer-viewer{padding:12px;border:1px solid #ccc;margin-bottom:20px;border-radius:6px;background:#fafafa}.layer-title{margin:0 0 6px;font-size:20px;font-weight:600}.layer-details{display:flex;justify-content:space-evenly;margin-bottom:12px;color:#444;font-size:14px}.layer-grid{display:flex;flex-wrap:wrap;gap:6px}.layer-grid img{width:100px;height:100px;image-rendering:pixelated}.card{background-color:#fff;border:1px solid #dcdfe6;border-radius:12px;padding:18px;display:flex;flex-direction:column;gap:24px;width:300px}.field{display:flex;flex-direction:column;gap:6px}.field label,.field legend{font-weight:600;color:#2c2f37;align-self:flex-start}.field input[type=text],.field input[type=number],.field textarea{border:1px solid #ccd0dd;border-radius:6px;padding:8px 10px;font-size:.95rem;font-family:inherit;background-color:#fafafe}.field textarea{resize:vertical}.field input[readonly],.field textarea[readonly]{background-color:#f4f6fc;color:#4a4f5c}.double-field{display:flex;gap:6px;width:100%;justify-content:space-between}.double-field .field{width:140px}.input-button{display:grid;grid-template-columns:1fr 1fr;gap:8px}.fieldset{border:1px solid #ccd0dd;border-radius:6px;padding:10px 12px 12px;display:flex;flex-direction:column;gap:6px}.fieldset label{font-weight:500;color:#343844}.slider-field{gap:10px}.slider-control{display:flex;align-items:center;gap:12px}.slider-control input[type=range]{width:80px;-webkit-appearance:none;background-color:#d3d3d3;border-radius:6px;accent-color:var(--accent-color);flex:1}.slider-control input[type=range]::-webkit-slider-runnable-track{height:6px;background:#d3d3d3;border-radius:6px}.slider-control input[type=range]::-webkit-slider-thumb{-webkit-appearance:none;width:14px;height:14px;border-radius:50%;background:var(--accent-color);cursor:pointer;margin-top:-4px}.slider-control span{min-width:28px;text-align:right;font-weight:600}.network-visualizer{display:grid;grid-template-columns:2fr 1fr;gap:20px}
|
|
|
|
|
|
dist/assets/index-CQRg13xj.css
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
@layer properties{@supports (((-webkit-hyphens:none)) and (not (margin-trim:inline))) or ((-moz-orient:inline) and (not (color:rgb(from red r g b)))){*,:before,:after,::backdrop{--tw-rotate-x:initial;--tw-rotate-y:initial;--tw-rotate-z:initial;--tw-skew-x:initial;--tw-skew-y:initial;--tw-space-y-reverse:0;--tw-border-style:solid;--tw-font-weight:initial;--tw-tracking:initial;--tw-ordinal:initial;--tw-slashed-zero:initial;--tw-numeric-figure:initial;--tw-numeric-spacing:initial;--tw-numeric-fraction:initial;--tw-shadow:0 0 #0000;--tw-shadow-color:initial;--tw-shadow-alpha:100%;--tw-inset-shadow:0 0 #0000;--tw-inset-shadow-color:initial;--tw-inset-shadow-alpha:100%;--tw-ring-color:initial;--tw-ring-shadow:0 0 #0000;--tw-inset-ring-color:initial;--tw-inset-ring-shadow:0 0 #0000;--tw-ring-inset:initial;--tw-ring-offset-width:0px;--tw-ring-offset-color:#fff;--tw-ring-offset-shadow:0 0 #0000;--tw-outline-style:solid;--tw-blur:initial;--tw-brightness:initial;--tw-contrast:initial;--tw-grayscale:initial;--tw-hue-rotate:initial;--tw-invert:initial;--tw-opacity:initial;--tw-saturate:initial;--tw-sepia:initial;--tw-drop-shadow:initial;--tw-drop-shadow-color:initial;--tw-drop-shadow-alpha:100%;--tw-drop-shadow-size:initial}}}@layer theme{:root,:host{--font-sans:ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji";--font-mono:ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;--color-red-100:oklch(93.6% .032 17.717);--color-red-500:oklch(63.7% .237 25.331);--color-orange-200:oklch(90.1% .076 70.697);--color-orange-300:oklch(83.7% .128 66.29);--color-orange-400:oklch(75% .183 55.934);--color-orange-500:oklch(70.5% .213 47.604);--color-orange-700:oklch(55.3% .195 38.402);--color-lime-300:oklch(89.7% .196 126.665);--color-lime-500:oklch(76.8% .233 130.85);--color-blue-600:oklch(54.6% .245 262.881);--color-slate-50:oklch(98.4% .003 247.858);--color-slate-200:oklch(92.9% .013 255.508);--color-slate-700:oklch(37.2% .044 257.287);--color-slate-800:oklch(27.9% .041 260.031);--color-slate-900:oklch(20.8% .042 265.755);--color-gray-100:oklch(96.7% .003 264.542);--color-gray-200:oklch(92.8% .006 264.531);--color-gray-300:oklch(87.2% .01 258.338);--color-gray-700:oklch(37.3% .034 259.733);--color-gray-950:oklch(13% .028 261.692);--color-white:#fff;--spacing:.25rem;--container-sm:24rem;--text-sm:.875rem;--text-sm--line-height:calc(1.25 / .875);--text-lg:1.125rem;--text-lg--line-height:calc(1.75 / 1.125);--text-xl:1.25rem;--text-xl--line-height:calc(1.75 / 1.25);--font-weight-medium:500;--font-weight-semibold:600;--font-weight-bold:700;--tracking-tight:-.025em;--radius-md:.375rem;--animate-spin:spin 1s linear infinite;--default-transition-duration:.15s;--default-transition-timing-function:cubic-bezier(.4, 0, .2, 1);--default-font-family:var(--font-sans);--default-mono-font-family:var(--font-mono)}}@layer base{*,:after,:before,::backdrop{box-sizing:border-box;border:0 solid;margin:0;padding:0}::file-selector-button{box-sizing:border-box;border:0 solid;margin:0;padding:0}html,:host{-webkit-text-size-adjust:100%;tab-size:4;line-height:1.5;font-family:var(--default-font-family,ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji");font-feature-settings:var(--default-font-feature-settings,normal);font-variation-settings:var(--default-font-variation-settings,normal);-webkit-tap-highlight-color:transparent}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;-webkit-text-decoration:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,samp,pre{font-family:var(--default-mono-font-family,ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace);font-feature-settings:var(--default-mono-font-feature-settings,normal);font-variation-settings:var(--default-mono-font-variation-settings,normal);font-size:1em}small{font-size:80%}sub,sup{vertical-align:baseline;font-size:75%;line-height:0;position:relative}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}:-moz-focusring{outline:auto}progress{vertical-align:baseline}summary{display:list-item}ol,ul,menu{list-style:none}img,svg,video,canvas,audio,iframe,embed,object{vertical-align:middle;display:block}img,video{max-width:100%;height:auto}button,input,select,optgroup,textarea{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}::file-selector-button{font:inherit;font-feature-settings:inherit;font-variation-settings:inherit;letter-spacing:inherit;color:inherit;opacity:1;background-color:#0000;border-radius:0}:where(select:is([multiple],[size])) optgroup{font-weight:bolder}:where(select:is([multiple],[size])) optgroup option{padding-inline-start:20px}::file-selector-button{margin-inline-end:4px}::placeholder{opacity:1}@supports (not ((-webkit-appearance:-apple-pay-button))) or (contain-intrinsic-size:1px){::placeholder{color:currentColor}@supports (color:color-mix(in lab,red,red)){::placeholder{color:color-mix(in oklab,currentcolor 50%,transparent)}}}textarea{resize:vertical}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-date-and-time-value{min-height:1lh;text-align:inherit}::-webkit-datetime-edit{display:inline-flex}::-webkit-datetime-edit-fields-wrapper{padding:0}::-webkit-datetime-edit{padding-block:0}::-webkit-datetime-edit-year-field{padding-block:0}::-webkit-datetime-edit-month-field{padding-block:0}::-webkit-datetime-edit-day-field{padding-block:0}::-webkit-datetime-edit-hour-field{padding-block:0}::-webkit-datetime-edit-minute-field{padding-block:0}::-webkit-datetime-edit-second-field{padding-block:0}::-webkit-datetime-edit-millisecond-field{padding-block:0}::-webkit-datetime-edit-meridiem-field{padding-block:0}::-webkit-calendar-picker-indicator{line-height:1}:-moz-ui-invalid{box-shadow:none}button,input:where([type=button],[type=reset],[type=submit]){appearance:button}::file-selector-button{appearance:button}::-webkit-inner-spin-button{height:auto}::-webkit-outer-spin-button{height:auto}[hidden]:where(:not([hidden=until-found])){display:none!important}}@layer components;@layer utilities{.invisible{visibility:hidden}.visible{visibility:visible}.visible\!{visibility:visible!important}.absolute{position:absolute}.fixed{position:fixed}.relative{position:relative}.static{position:static}.sticky{position:sticky}.inset-0{inset:calc(var(--spacing) * 0)}.start{inset-inline-start:var(--spacing)}.start\!{inset-inline-start:var(--spacing)!important}.end{inset-inline-end:var(--spacing)}.end\!{inset-inline-end:var(--spacing)!important}.isolate{isolation:isolate}.z-50{z-index:50}.container{width:100%}@media(min-width:40rem){.container{max-width:40rem}}@media(min-width:48rem){.container{max-width:48rem}}@media(min-width:64rem){.container{max-width:64rem}}@media(min-width:80rem){.container{max-width:80rem}}@media(min-width:96rem){.container{max-width:96rem}}.m-125{margin:calc(var(--spacing) * 125)}.m-185{margin:calc(var(--spacing) * 185)}.m-214{margin:calc(var(--spacing) * 214)}.m-374{margin:calc(var(--spacing) * 374)}.m-571{margin:calc(var(--spacing) * 571)}.m-750{margin:calc(var(--spacing) * 750)}.m-812{margin:calc(var(--spacing) * 812)}.mt-2{margin-top:calc(var(--spacing) * 2)}.mt-3{margin-top:calc(var(--spacing) * 3)}.mb-2{margin-bottom:calc(var(--spacing) * 2)}.mb-3{margin-bottom:calc(var(--spacing) * 3)}.mb-4{margin-bottom:calc(var(--spacing) * 4)}.ml-2{margin-left:calc(var(--spacing) * 2)}.block{display:block}.flex{display:flex}.grid{display:grid}.hidden{display:none}.inline{display:inline}.inline-block{display:inline-block}.table{display:table}.h-16{height:calc(var(--spacing) * 16)}.h-18{height:calc(var(--spacing) * 18)}.h-24{height:calc(var(--spacing) * 24)}.h-dvh{height:100dvh}.h-full{height:100%}.max-h-64{max-height:calc(var(--spacing) * 64)}.max-h-full{max-height:100%}.min-h-0{min-height:calc(var(--spacing) * 0)}.w-6{width:calc(var(--spacing) * 6)}.w-14{width:calc(var(--spacing) * 14)}.w-16{width:calc(var(--spacing) * 16)}.w-24{width:calc(var(--spacing) * 24)}.w-full{width:100%}.max-w-full{max-width:100%}.max-w-sm{max-width:var(--container-sm)}.min-w-0{min-width:calc(var(--spacing) * 0)}.transform{transform:var(--tw-rotate-x,) var(--tw-rotate-y,) var(--tw-rotate-z,) var(--tw-skew-x,) var(--tw-skew-y,)}.animate-spin{animation:var(--animate-spin)}.cursor-col-resize{cursor:col-resize}.cursor-crosshair{cursor:crosshair}.cursor-default{cursor:default}.cursor-e-resize{cursor:e-resize}.cursor-ew-resize{cursor:ew-resize}.cursor-grab{cursor:grab}.cursor-move{cursor:move}.cursor-n-resize{cursor:n-resize}.cursor-ne-resize{cursor:ne-resize}.cursor-ns-resize{cursor:ns-resize}.cursor-nw-resize{cursor:nw-resize}.cursor-pointer{cursor:pointer}.cursor-row-resize{cursor:row-resize}.cursor-s-resize{cursor:s-resize}.cursor-se-resize{cursor:se-resize}.cursor-sw-resize{cursor:sw-resize}.cursor-w-resize{cursor:w-resize}.resize{resize:both}.resize-none{resize:none}.grid-cols-2{grid-template-columns:repeat(2,minmax(0,1fr))}.grid-cols-4{grid-template-columns:repeat(4,minmax(0,1fr))}.grid-cols-\[2fr_1fr\]{grid-template-columns:2fr 1fr}.grid-rows-2{grid-template-rows:repeat(2,minmax(0,1fr))}.flex-col{flex-direction:column}.flex-wrap{flex-wrap:wrap}.items-center{align-items:center}.items-end{align-items:flex-end}.justify-center{justify-content:center}.gap-1{gap:calc(var(--spacing) * 1)}.gap-2{gap:calc(var(--spacing) * 2)}.gap-3{gap:calc(var(--spacing) * 3)}.gap-4{gap:calc(var(--spacing) * 4)}:where(.space-y-4>:not(:last-child)){--tw-space-y-reverse:0;margin-block-start:calc(calc(var(--spacing) * 4) * var(--tw-space-y-reverse));margin-block-end:calc(calc(var(--spacing) * 4) * calc(1 - var(--tw-space-y-reverse)))}:where(.space-y-6>:not(:last-child)){--tw-space-y-reverse:0;margin-block-start:calc(calc(var(--spacing) * 6) * var(--tw-space-y-reverse));margin-block-end:calc(calc(var(--spacing) * 6) * calc(1 - var(--tw-space-y-reverse)))}.rounded{border-radius:.25rem}.rounded-full{border-radius:3.40282e38px}.rounded-md{border-radius:var(--radius-md)}.border{border-style:var(--tw-border-style);border-width:1px}.border-4{border-style:var(--tw-border-style);border-width:4px}.border-b-2{border-bottom-style:var(--tw-border-style);border-bottom-width:2px}.border-l{border-left-style:var(--tw-border-style);border-left-width:1px}.border-gray-200{border-color:var(--color-gray-200)}.border-gray-300{border-color:var(--color-gray-300)}.border-lime-500{border-color:var(--color-lime-500)}.border-orange-300{border-color:var(--color-orange-300)}.border-orange-400{border-color:var(--color-orange-400)}.border-orange-500{border-color:var(--color-orange-500)}.border-slate-200{border-color:var(--color-slate-200)}.border-t-blue-600{border-top-color:var(--color-blue-600)}.bg-gray-100{background-color:var(--color-gray-100)}.bg-orange-200{background-color:var(--color-orange-200)}.bg-red-100{background-color:var(--color-red-100)}.bg-slate-50{background-color:var(--color-slate-50)}.bg-white{background-color:var(--color-white)}.object-contain{object-fit:contain}.p-2{padding:calc(var(--spacing) * 2)}.p-3{padding:calc(var(--spacing) * 3)}.p-4{padding:calc(var(--spacing) * 4)}.px-2{padding-inline:calc(var(--spacing) * 2)}.px-3{padding-inline:calc(var(--spacing) * 3)}.px-5{padding-inline:calc(var(--spacing) * 5)}.py-1{padding-block:calc(var(--spacing) * 1)}.py-2{padding-block:calc(var(--spacing) * 2)}.text-center{text-align:center}.text-justify{text-align:justify}.text-lg{font-size:var(--text-lg);line-height:var(--tw-leading,var(--text-lg--line-height))}.text-sm{font-size:var(--text-sm);line-height:var(--tw-leading,var(--text-sm--line-height))}.text-xl{font-size:var(--text-xl);line-height:var(--tw-leading,var(--text-xl--line-height))}.font-bold{--tw-font-weight:var(--font-weight-bold);font-weight:var(--font-weight-bold)}.font-medium{--tw-font-weight:var(--font-weight-medium);font-weight:var(--font-weight-medium)}.font-semibold{--tw-font-weight:var(--font-weight-semibold);font-weight:var(--font-weight-semibold)}.tracking-tight{--tw-tracking:var(--tracking-tight);letter-spacing:var(--tracking-tight)}.text-gray-700{color:var(--color-gray-700)}.text-gray-950{color:var(--color-gray-950)}.text-orange-400{color:var(--color-orange-400)}.text-orange-500{color:var(--color-orange-500)}.text-orange-700{color:var(--color-orange-700)}.text-red-500{color:var(--color-red-500)}.text-slate-700{color:var(--color-slate-700)}.text-slate-800{color:var(--color-slate-800)}.text-slate-900{color:var(--color-slate-900)}.capitalize{text-transform:capitalize}.lowercase{text-transform:lowercase}.uppercase{text-transform:uppercase}.italic{font-style:italic}.ordinal{--tw-ordinal:ordinal;font-variant-numeric:var(--tw-ordinal,) var(--tw-slashed-zero,) var(--tw-numeric-figure,) var(--tw-numeric-spacing,) var(--tw-numeric-fraction,)}.line-through{text-decoration-line:line-through}.overline{text-decoration-line:overline}.underline{text-decoration-line:underline}.shadow,.shadow-sm{--tw-shadow:0 1px 3px 0 var(--tw-shadow-color,#0000001a), 0 1px 2px -1px var(--tw-shadow-color,#0000001a);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.ring{--tw-ring-shadow:var(--tw-ring-inset,) 0 0 0 calc(1px + var(--tw-ring-offset-width)) var(--tw-ring-color,currentcolor);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.ring-2{--tw-ring-shadow:var(--tw-ring-inset,) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color,currentcolor);box-shadow:var(--tw-inset-shadow),var(--tw-inset-ring-shadow),var(--tw-ring-offset-shadow),var(--tw-ring-shadow),var(--tw-shadow)}.ring-lime-300{--tw-ring-color:var(--color-lime-300)}.outline{outline-style:var(--tw-outline-style);outline-width:1px}.blur{--tw-blur:blur(8px);filter:var(--tw-blur,) var(--tw-brightness,) var(--tw-contrast,) var(--tw-grayscale,) var(--tw-hue-rotate,) var(--tw-invert,) var(--tw-saturate,) var(--tw-sepia,) var(--tw-drop-shadow,)}.grayscale{--tw-grayscale:grayscale(100%);filter:var(--tw-blur,) var(--tw-brightness,) var(--tw-contrast,) var(--tw-grayscale,) var(--tw-hue-rotate,) var(--tw-invert,) var(--tw-saturate,) var(--tw-sepia,) var(--tw-drop-shadow,)}.invert{--tw-invert:invert(100%);filter:var(--tw-blur,) var(--tw-brightness,) var(--tw-contrast,) var(--tw-grayscale,) var(--tw-hue-rotate,) var(--tw-invert,) var(--tw-saturate,) var(--tw-sepia,) var(--tw-drop-shadow,)}.filter{filter:var(--tw-blur,) var(--tw-brightness,) var(--tw-contrast,) var(--tw-grayscale,) var(--tw-hue-rotate,) var(--tw-invert,) var(--tw-saturate,) var(--tw-sepia,) var(--tw-drop-shadow,)}.transition{transition-property:color,background-color,border-color,outline-color,text-decoration-color,fill,stroke,--tw-gradient-from,--tw-gradient-via,--tw-gradient-to,opacity,box-shadow,transform,translate,scale,rotate,filter,-webkit-backdrop-filter,backdrop-filter,display,content-visibility,overlay,pointer-events;transition-timing-function:var(--tw-ease,var(--default-transition-timing-function));transition-duration:var(--tw-duration,var(--default-transition-duration))}@media(hover:hover){.hover\:bg-gray-200:hover{background-color:var(--color-gray-200)}.hover\:bg-orange-300:hover{background-color:var(--color-orange-300)}}@media(min-width:40rem){.sm\:grid-cols-3{grid-template-columns:repeat(3,minmax(0,1fr))}.sm\:grid-cols-6{grid-template-columns:repeat(6,minmax(0,1fr))}}@media(min-width:48rem){.md\:grid-cols-8{grid-template-columns:repeat(8,minmax(0,1fr))}}}@property --tw-rotate-x{syntax:"*";inherits:false}@property --tw-rotate-y{syntax:"*";inherits:false}@property --tw-rotate-z{syntax:"*";inherits:false}@property --tw-skew-x{syntax:"*";inherits:false}@property --tw-skew-y{syntax:"*";inherits:false}@property --tw-space-y-reverse{syntax:"*";inherits:false;initial-value:0}@property --tw-border-style{syntax:"*";inherits:false;initial-value:solid}@property --tw-font-weight{syntax:"*";inherits:false}@property --tw-tracking{syntax:"*";inherits:false}@property --tw-ordinal{syntax:"*";inherits:false}@property --tw-slashed-zero{syntax:"*";inherits:false}@property --tw-numeric-figure{syntax:"*";inherits:false}@property --tw-numeric-spacing{syntax:"*";inherits:false}@property --tw-numeric-fraction{syntax:"*";inherits:false}@property --tw-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-shadow-color{syntax:"*";inherits:false}@property --tw-shadow-alpha{syntax:"<percentage>";inherits:false;initial-value:100%}@property --tw-inset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-shadow-color{syntax:"*";inherits:false}@property --tw-inset-shadow-alpha{syntax:"<percentage>";inherits:false;initial-value:100%}@property --tw-ring-color{syntax:"*";inherits:false}@property --tw-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-inset-ring-color{syntax:"*";inherits:false}@property --tw-inset-ring-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-ring-inset{syntax:"*";inherits:false}@property --tw-ring-offset-width{syntax:"<length>";inherits:false;initial-value:0}@property --tw-ring-offset-color{syntax:"*";inherits:false;initial-value:#fff}@property --tw-ring-offset-shadow{syntax:"*";inherits:false;initial-value:0 0 #0000}@property --tw-outline-style{syntax:"*";inherits:false;initial-value:solid}@property --tw-blur{syntax:"*";inherits:false}@property --tw-brightness{syntax:"*";inherits:false}@property --tw-contrast{syntax:"*";inherits:false}@property --tw-grayscale{syntax:"*";inherits:false}@property --tw-hue-rotate{syntax:"*";inherits:false}@property --tw-invert{syntax:"*";inherits:false}@property --tw-opacity{syntax:"*";inherits:false}@property --tw-saturate{syntax:"*";inherits:false}@property --tw-sepia{syntax:"*";inherits:false}@property --tw-drop-shadow{syntax:"*";inherits:false}@property --tw-drop-shadow-color{syntax:"*";inherits:false}@property --tw-drop-shadow-alpha{syntax:"<percentage>";inherits:false;initial-value:100%}@property --tw-drop-shadow-size{syntax:"*";inherits:false}@keyframes spin{to{transform:rotate(360deg)}}
|
dist/assets/{index-CARha6nB.js → index-p2vRWBSG.js}
RENAMED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dist/index.html
CHANGED
|
@@ -4,9 +4,9 @@
|
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
-
<title>
|
| 8 |
-
<script type="module" crossorigin src="/assets/index-
|
| 9 |
-
<link rel="stylesheet" crossorigin href="/assets/index-
|
| 10 |
</head>
|
| 11 |
<body>
|
| 12 |
<div id="root"></div>
|
|
|
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>cnn_visualizer2</title>
|
| 8 |
+
<script type="module" crossorigin src="/assets/index-p2vRWBSG.js"></script>
|
| 9 |
+
<link rel="stylesheet" crossorigin href="/assets/index-CQRg13xj.css">
|
| 10 |
</head>
|
| 11 |
<body>
|
| 12 |
<div id="root"></div>
|
eslint.config.js
CHANGED
|
@@ -2,28 +2,22 @@ import js from '@eslint/js'
|
|
| 2 |
import globals from 'globals'
|
| 3 |
import reactHooks from 'eslint-plugin-react-hooks'
|
| 4 |
import reactRefresh from 'eslint-plugin-react-refresh'
|
|
|
|
| 5 |
import { defineConfig, globalIgnores } from 'eslint/config'
|
| 6 |
|
| 7 |
export default defineConfig([
|
| 8 |
globalIgnores(['dist']),
|
| 9 |
{
|
| 10 |
-
files: ['**/*.{
|
| 11 |
extends: [
|
| 12 |
js.configs.recommended,
|
|
|
|
| 13 |
reactHooks.configs.flat.recommended,
|
| 14 |
reactRefresh.configs.vite,
|
| 15 |
],
|
| 16 |
languageOptions: {
|
| 17 |
ecmaVersion: 2020,
|
| 18 |
globals: globals.browser,
|
| 19 |
-
parserOptions: {
|
| 20 |
-
ecmaVersion: 'latest',
|
| 21 |
-
ecmaFeatures: { jsx: true },
|
| 22 |
-
sourceType: 'module',
|
| 23 |
-
},
|
| 24 |
-
},
|
| 25 |
-
rules: {
|
| 26 |
-
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
|
| 27 |
},
|
| 28 |
},
|
| 29 |
])
|
|
|
|
| 2 |
import globals from 'globals'
|
| 3 |
import reactHooks from 'eslint-plugin-react-hooks'
|
| 4 |
import reactRefresh from 'eslint-plugin-react-refresh'
|
| 5 |
+
import tseslint from 'typescript-eslint'
|
| 6 |
import { defineConfig, globalIgnores } from 'eslint/config'
|
| 7 |
|
| 8 |
export default defineConfig([
|
| 9 |
globalIgnores(['dist']),
|
| 10 |
{
|
| 11 |
+
files: ['**/*.{ts,tsx}'],
|
| 12 |
extends: [
|
| 13 |
js.configs.recommended,
|
| 14 |
+
tseslint.configs.recommended,
|
| 15 |
reactHooks.configs.flat.recommended,
|
| 16 |
reactRefresh.configs.vite,
|
| 17 |
],
|
| 18 |
languageOptions: {
|
| 19 |
ecmaVersion: 2020,
|
| 20 |
globals: globals.browser,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
},
|
| 22 |
},
|
| 23 |
])
|
index.html
CHANGED
|
@@ -4,10 +4,10 @@
|
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
-
<title>
|
| 8 |
</head>
|
| 9 |
<body>
|
| 10 |
<div id="root"></div>
|
| 11 |
-
<script type="module" src="/src/main.
|
| 12 |
</body>
|
| 13 |
</html>
|
|
|
|
| 4 |
<meta charset="UTF-8" />
|
| 5 |
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 6 |
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 7 |
+
<title>cnn_visualizer2</title>
|
| 8 |
</head>
|
| 9 |
<body>
|
| 10 |
<div id="root"></div>
|
| 11 |
+
<script type="module" src="/src/main.tsx"></script>
|
| 12 |
</body>
|
| 13 |
</html>
|
package-lock.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
package.json
CHANGED
|
@@ -1,22 +1,27 @@
|
|
| 1 |
{
|
| 2 |
-
"name": "
|
| 3 |
"private": true,
|
| 4 |
"version": "0.0.0",
|
| 5 |
"type": "module",
|
| 6 |
"scripts": {
|
| 7 |
"dev": "vite",
|
| 8 |
-
"build": "vite build",
|
| 9 |
"lint": "eslint .",
|
| 10 |
"preview": "vite preview"
|
| 11 |
},
|
| 12 |
"dependencies": {
|
|
|
|
| 13 |
"@tensorflow/tfjs": "^4.22.0",
|
|
|
|
|
|
|
| 14 |
"react": "^19.2.0",
|
| 15 |
"react-dom": "^19.2.0",
|
| 16 |
-
"react-plotly.js": "^2.6.0"
|
|
|
|
| 17 |
},
|
| 18 |
"devDependencies": {
|
| 19 |
"@eslint/js": "^9.39.1",
|
|
|
|
| 20 |
"@types/react": "^19.2.7",
|
| 21 |
"@types/react-dom": "^19.2.3",
|
| 22 |
"@vitejs/plugin-react": "^5.1.1",
|
|
@@ -24,6 +29,8 @@
|
|
| 24 |
"eslint-plugin-react-hooks": "^7.0.1",
|
| 25 |
"eslint-plugin-react-refresh": "^0.4.24",
|
| 26 |
"globals": "^16.5.0",
|
|
|
|
|
|
|
| 27 |
"vite": "^7.3.1"
|
| 28 |
}
|
| 29 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"name": "cnn_visualizer2",
|
| 3 |
"private": true,
|
| 4 |
"version": "0.0.0",
|
| 5 |
"type": "module",
|
| 6 |
"scripts": {
|
| 7 |
"dev": "vite",
|
| 8 |
+
"build": "tsc -b && vite build",
|
| 9 |
"lint": "eslint .",
|
| 10 |
"preview": "vite preview"
|
| 11 |
},
|
| 12 |
"dependencies": {
|
| 13 |
+
"@tailwindcss/vite": "^4.2.0",
|
| 14 |
"@tensorflow/tfjs": "^4.22.0",
|
| 15 |
+
"@tensorflow/tfjs-backend-wasm": "^4.22.0",
|
| 16 |
+
"plotly.js": "^3.4.0",
|
| 17 |
"react": "^19.2.0",
|
| 18 |
"react-dom": "^19.2.0",
|
| 19 |
+
"react-plotly.js": "^2.6.0",
|
| 20 |
+
"tailwindcss": "^4.2.0"
|
| 21 |
},
|
| 22 |
"devDependencies": {
|
| 23 |
"@eslint/js": "^9.39.1",
|
| 24 |
+
"@types/node": "^24.10.1",
|
| 25 |
"@types/react": "^19.2.7",
|
| 26 |
"@types/react-dom": "^19.2.3",
|
| 27 |
"@vitejs/plugin-react": "^5.1.1",
|
|
|
|
| 29 |
"eslint-plugin-react-hooks": "^7.0.1",
|
| 30 |
"eslint-plugin-react-refresh": "^0.4.24",
|
| 31 |
"globals": "^16.5.0",
|
| 32 |
+
"typescript": "~5.9.3",
|
| 33 |
+
"typescript-eslint": "^8.48.0",
|
| 34 |
"vite": "^7.3.1"
|
| 35 |
}
|
| 36 |
}
|
src/App.tsx
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useState } from 'react'
|
| 2 |
+
import Tabs from "./ui/Tabs.tsx"
|
| 3 |
+
import ConvolutionVisualizer from "./ConvolutionVisualizer.tsx"
|
| 4 |
+
import NetworkVisualizer from "./NetworkVisualizer.tsx"
|
| 5 |
+
|
| 6 |
+
function App() {
|
| 7 |
+
const tabs = ["Convolution", "Network"];
|
| 8 |
+
const [activeTab, setActiveTab] = useState<string>(tabs[0]);
|
| 9 |
+
|
| 10 |
+
return (
|
| 11 |
+
<>
|
| 12 |
+
<Tabs
|
| 13 |
+
tabs={tabs}
|
| 14 |
+
activeTab={activeTab}
|
| 15 |
+
onChange={setActiveTab}
|
| 16 |
+
/>
|
| 17 |
+
{ activeTab === "Convolution" && (
|
| 18 |
+
<ConvolutionVisualizer />
|
| 19 |
+
)}
|
| 20 |
+
{ activeTab === "Network" && (
|
| 21 |
+
<NetworkVisualizer />
|
| 22 |
+
)}
|
| 23 |
+
</>
|
| 24 |
+
);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
export default App;
|
src/ConvolutionVisualizer.tsx
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState, useRef } from "react";
|
| 2 |
+
import marioImage from "./assets/mario.png";
|
| 3 |
+
import Button from "./ui/Button.tsx";
|
| 4 |
+
import Radio from "./ui/Radio";
|
| 5 |
+
import useConvolutionProcessing from "./useConvolutionProcessing.ts";
|
| 6 |
+
import Dropdown from "./ui/Dropdown.tsx";
|
| 7 |
+
|
| 8 |
+
import {
|
| 9 |
+
DEFAULT_COLOR_KERNEL,
|
| 10 |
+
DEFAULT_GRAY_KERNEL,
|
| 11 |
+
GRAY_KERNEL_PRESETS,
|
| 12 |
+
COLOR_KERNEL_PRESETS,
|
| 13 |
+
} from "./kernels.ts";
|
| 14 |
+
|
| 15 |
+
const DEFAULT_IMAGE: string = marioImage;
|
| 16 |
+
const MIN_KERNEL_SIZE = 1;
|
| 17 |
+
const MAX_KERNEL_SIZE = 20;
|
| 18 |
+
|
| 19 |
+
export default function ConvolutionVisualizer() {
|
| 20 |
+
const [rawInputImage, setRawInputImage] = useState<string>(DEFAULT_IMAGE);
|
| 21 |
+
const [useColor, setUseColor] = useState<boolean>(true);
|
| 22 |
+
const [colorKernel, setColorKernel] = useState<number[][][]>(DEFAULT_COLOR_KERNEL);
|
| 23 |
+
const [grayscaleKernel, setGrayscaleKernel] = useState<number[][]>(DEFAULT_GRAY_KERNEL);
|
| 24 |
+
const kernel = useColor ? colorKernel : grayscaleKernel;
|
| 25 |
+
|
| 26 |
+
const [inputImage, outputImage] = useConvolutionProcessing(rawInputImage, kernel);
|
| 27 |
+
|
| 28 |
+
return (
|
| 29 |
+
<div className="grid grid-cols-2">
|
| 30 |
+
<InputOutputViewer
|
| 31 |
+
input={inputImage}
|
| 32 |
+
output={outputImage}
|
| 33 |
+
/>
|
| 34 |
+
<Sidebar
|
| 35 |
+
setImage={setRawInputImage}
|
| 36 |
+
useColor={useColor}
|
| 37 |
+
setUseColor={setUseColor}
|
| 38 |
+
colorKernel={colorKernel}
|
| 39 |
+
setColorKernel={setColorKernel}
|
| 40 |
+
grayscaleKernel={grayscaleKernel}
|
| 41 |
+
setGrayscaleKernel={setGrayscaleKernel}
|
| 42 |
+
/>
|
| 43 |
+
</div>
|
| 44 |
+
);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
type InputOutputViewerProps = {
|
| 48 |
+
input: string | null;
|
| 49 |
+
output: string | null;
|
| 50 |
+
}
|
| 51 |
+
function InputOutputViewer({ input, output }: InputOutputViewerProps) {
|
| 52 |
+
return (
|
| 53 |
+
<div className="grid grid-rows-2">
|
| 54 |
+
<div className="flex flex-col items-center justify-center">
|
| 55 |
+
<h2 className="text-lg font-bold mb-2">Input Image</h2>
|
| 56 |
+
{input ? (
|
| 57 |
+
<img src={input} alt="Input" className="max-w-full max-h-full" />
|
| 58 |
+
) : (
|
| 59 |
+
<p>Loading...</p>
|
| 60 |
+
)}
|
| 61 |
+
</div>
|
| 62 |
+
<div className="flex flex-col items-center justify-center">
|
| 63 |
+
<h2 className="text-lg font-bold mb-2">Output Image</h2>
|
| 64 |
+
{output ? (
|
| 65 |
+
<img src={output} alt="Output" className="max-w-full max-h-full" />
|
| 66 |
+
) : (
|
| 67 |
+
<p>Processing...</p>
|
| 68 |
+
)}
|
| 69 |
+
</div>
|
| 70 |
+
</div>
|
| 71 |
+
);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
type SidebarProps = {
|
| 75 |
+
setImage: (imageUrl: string) => void;
|
| 76 |
+
useColor: boolean;
|
| 77 |
+
setUseColor: (useColor: boolean) => void;
|
| 78 |
+
colorKernel: number[][][];
|
| 79 |
+
setColorKernel: (kernel: number[][][]) => void;
|
| 80 |
+
grayscaleKernel: number[][];
|
| 81 |
+
setGrayscaleKernel: (kernel: number[][]) => void;
|
| 82 |
+
}
|
| 83 |
+
function Sidebar({
|
| 84 |
+
setImage,
|
| 85 |
+
useColor,
|
| 86 |
+
setUseColor,
|
| 87 |
+
colorKernel,
|
| 88 |
+
setColorKernel,
|
| 89 |
+
grayscaleKernel,
|
| 90 |
+
setGrayscaleKernel,
|
| 91 |
+
}: SidebarProps) {
|
| 92 |
+
const imageFileInputRef = useRef<HTMLInputElement | null>(null);
|
| 93 |
+
|
| 94 |
+
return (
|
| 95 |
+
<div className="flex flex-col items-center border-l border-gray-200 p-4 gap-4">
|
| 96 |
+
<input
|
| 97 |
+
ref={imageFileInputRef}
|
| 98 |
+
type="file"
|
| 99 |
+
accept="image/*"
|
| 100 |
+
onChange={(e) => {
|
| 101 |
+
const file = e.target.files?.[0];
|
| 102 |
+
if (!file) return;
|
| 103 |
+
|
| 104 |
+
const reader = new FileReader();
|
| 105 |
+
reader.onload = () => {
|
| 106 |
+
const result = reader.result;
|
| 107 |
+
if (typeof result !== "string") {
|
| 108 |
+
return;
|
| 109 |
+
}
|
| 110 |
+
setImage(result);
|
| 111 |
+
};
|
| 112 |
+
reader.readAsDataURL(file);
|
| 113 |
+
}}
|
| 114 |
+
style={{ display: "none" }}
|
| 115 |
+
/>
|
| 116 |
+
<Button
|
| 117 |
+
label="Upload Image"
|
| 118 |
+
onClick={() => imageFileInputRef.current?.click()}
|
| 119 |
+
/>
|
| 120 |
+
<Radio
|
| 121 |
+
label="Color option"
|
| 122 |
+
options={["Grayscale", "Color"] as const}
|
| 123 |
+
activeOption={useColor ? "Color" : "Grayscale"}
|
| 124 |
+
onChange={(option) => setUseColor(option === "Color")}
|
| 125 |
+
/>
|
| 126 |
+
<KernelEditor
|
| 127 |
+
useColor={useColor}
|
| 128 |
+
colorKernel={colorKernel}
|
| 129 |
+
setColorKernel={setColorKernel}
|
| 130 |
+
grayscaleKernel={grayscaleKernel}
|
| 131 |
+
setGrayscaleKernel={setGrayscaleKernel}
|
| 132 |
+
/>
|
| 133 |
+
</div>
|
| 134 |
+
);
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
type KernelEditorProps = {
|
| 138 |
+
useColor: boolean;
|
| 139 |
+
colorKernel: number[][][];
|
| 140 |
+
setColorKernel: (kernel: number[][][]) => void;
|
| 141 |
+
grayscaleKernel: number[][];
|
| 142 |
+
setGrayscaleKernel: (kernel: number[][]) => void;
|
| 143 |
+
}
|
| 144 |
+
function KernelEditor({
|
| 145 |
+
useColor,
|
| 146 |
+
colorKernel,
|
| 147 |
+
setColorKernel,
|
| 148 |
+
grayscaleKernel,
|
| 149 |
+
setGrayscaleKernel,
|
| 150 |
+
}: KernelEditorProps) {
|
| 151 |
+
const colorPresetNames = Object.keys(COLOR_KERNEL_PRESETS) as string[];
|
| 152 |
+
const grayPresetNames = Object.keys(GRAY_KERNEL_PRESETS) as string[];
|
| 153 |
+
const [selectedChannel, setSelectedChannel] = useState<number>(0);
|
| 154 |
+
const [selectedColorPreset, setSelectedColorPreset] = useState<string>(colorPresetNames[0]);
|
| 155 |
+
const [selectedGrayPreset, setSelectedGrayPreset] = useState<string>(grayPresetNames[0]);
|
| 156 |
+
const [loadedColorPreset, setLoadedColorPreset] = useState<string | null>("Laplacian");
|
| 157 |
+
const [loadedGrayPreset, setLoadedGrayPreset] = useState<string | null>("Laplacian");
|
| 158 |
+
|
| 159 |
+
function colorKernelToRepr(nextKernel: number[][][]): string[][][] {
|
| 160 |
+
return nextKernel.map((channel) =>
|
| 161 |
+
channel.map((row) => row.map((value) => value.toString())),
|
| 162 |
+
);
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
function grayKernelToRepr(nextKernel: number[][]): string[][] {
|
| 166 |
+
return nextKernel.map((row) => row.map((value) => value.toString()));
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
const [draftColorKernel, setDraftColorKernel] = useState<string[][][]>(colorKernelToRepr(colorKernel));
|
| 170 |
+
const [draftGrayKernel, setDraftGrayKernel] = useState<string[][]>(grayKernelToRepr(grayscaleKernel));
|
| 171 |
+
|
| 172 |
+
useEffect(() => {
|
| 173 |
+
setDraftColorKernel(colorKernelToRepr(colorKernel));
|
| 174 |
+
}, [colorKernel]);
|
| 175 |
+
|
| 176 |
+
useEffect(() => {
|
| 177 |
+
setDraftGrayKernel(grayKernelToRepr(grayscaleKernel));
|
| 178 |
+
}, [grayscaleKernel]);
|
| 179 |
+
|
| 180 |
+
useEffect(() => {
|
| 181 |
+
setSelectedChannel(0);
|
| 182 |
+
}, [useColor]);
|
| 183 |
+
|
| 184 |
+
function parseCell(cell: string): number | null {
|
| 185 |
+
const trimmed = cell.trim();
|
| 186 |
+
if (trimmed === "") return null;
|
| 187 |
+
const parsed = Number(trimmed);
|
| 188 |
+
return Number.isFinite(parsed) ? parsed : null;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
function cloneColorKernel(nextKernel: number[][][]): number[][][] {
|
| 192 |
+
return nextKernel.map((channel) => channel.map((row) => [...row]));
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
function cloneGrayKernel(nextKernel: number[][]): number[][] {
|
| 196 |
+
return nextKernel.map((row) => [...row]);
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
function handleLoadPreset() {
|
| 200 |
+
if (useColor) {
|
| 201 |
+
const presetKernel = cloneColorKernel(
|
| 202 |
+
COLOR_KERNEL_PRESETS[selectedColorPreset as keyof typeof COLOR_KERNEL_PRESETS],
|
| 203 |
+
);
|
| 204 |
+
setDraftColorKernel(colorKernelToRepr(presetKernel));
|
| 205 |
+
setColorKernel(presetKernel);
|
| 206 |
+
setLoadedColorPreset(selectedColorPreset);
|
| 207 |
+
setSelectedChannel(0);
|
| 208 |
+
return;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
const presetKernel = cloneGrayKernel(
|
| 212 |
+
GRAY_KERNEL_PRESETS[selectedGrayPreset as keyof typeof GRAY_KERNEL_PRESETS],
|
| 213 |
+
);
|
| 214 |
+
setDraftGrayKernel(grayKernelToRepr(presetKernel));
|
| 215 |
+
setGrayscaleKernel(presetKernel);
|
| 216 |
+
setLoadedGrayPreset(selectedGrayPreset);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
function kernelsEqual2D(a: number[][], b: number[][]): boolean {
|
| 220 |
+
if (a.length !== b.length) return false;
|
| 221 |
+
if ((a[0]?.length ?? 0) !== (b[0]?.length ?? 0)) return false;
|
| 222 |
+
return a.every((row, rowIndex) =>
|
| 223 |
+
row.every((value, colIndex) => value === b[rowIndex][colIndex]),
|
| 224 |
+
);
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
function kernelsEqual3D(a: number[][][], b: number[][][]): boolean {
|
| 228 |
+
if (a.length !== b.length) return false;
|
| 229 |
+
return a.every((channel, channelIndex) => kernelsEqual2D(channel, b[channelIndex]));
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
function handleKernelChange(channel: number, row: number, col: number, value: string) {
|
| 233 |
+
if (useColor) {
|
| 234 |
+
const nextDraft = draftColorKernel.map((c) => c.map((r) => [...r]));
|
| 235 |
+
nextDraft[channel][row][col] = value;
|
| 236 |
+
setDraftColorKernel(nextDraft);
|
| 237 |
+
|
| 238 |
+
const parsedKernel = nextDraft.map((c) => c.map((r) => r.map(parseCell)));
|
| 239 |
+
const isValid = parsedKernel.every((c) => c.every((r) => r.every((v) => v !== null)));
|
| 240 |
+
if (isValid) {
|
| 241 |
+
setColorKernel(parsedKernel as number[][][]);
|
| 242 |
+
}
|
| 243 |
+
return;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
const nextDraft = draftGrayKernel.map((r) => [...r]);
|
| 247 |
+
nextDraft[row][col] = value;
|
| 248 |
+
setDraftGrayKernel(nextDraft);
|
| 249 |
+
|
| 250 |
+
const parsedKernel = nextDraft.map((r) => r.map(parseCell));
|
| 251 |
+
const isValid = parsedKernel.every((r) => r.every((v) => v !== null));
|
| 252 |
+
if (isValid) {
|
| 253 |
+
setGrayscaleKernel(parsedKernel as number[][]);
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
function handleSizeChange(newWidth: number, newHeight: number) {
|
| 258 |
+
const clampedWidth = Math.min(MAX_KERNEL_SIZE, Math.max(MIN_KERNEL_SIZE, newWidth));
|
| 259 |
+
const clampedHeight = Math.min(MAX_KERNEL_SIZE, Math.max(MIN_KERNEL_SIZE, newHeight));
|
| 260 |
+
|
| 261 |
+
function resizeMatrix(matrix: string[][], width: number, height: number): string[][] {
|
| 262 |
+
return Array.from({ length: height }, (_, rowIndex) =>
|
| 263 |
+
Array.from({ length: width }, (_, colIndex) => matrix[rowIndex]?.[colIndex] ?? "0"),
|
| 264 |
+
);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
if (useColor) {
|
| 268 |
+
const resizedDraft = draftColorKernel.map((channelMatrix) =>
|
| 269 |
+
resizeMatrix(channelMatrix, clampedWidth, clampedHeight),
|
| 270 |
+
);
|
| 271 |
+
setDraftColorKernel(resizedDraft);
|
| 272 |
+
setColorKernel(
|
| 273 |
+
resizedDraft.map((channelMatrix) =>
|
| 274 |
+
channelMatrix.map((row) => row.map((cell) => parseCell(cell) ?? 0)),
|
| 275 |
+
),
|
| 276 |
+
);
|
| 277 |
+
return;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
const resizedDraft = resizeMatrix(draftGrayKernel, clampedWidth, clampedHeight);
|
| 281 |
+
setDraftGrayKernel(resizedDraft);
|
| 282 |
+
setGrayscaleKernel(resizedDraft.map((row) => row.map((cell) => parseCell(cell) ?? 0)));
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
const channelLabels = ["R", "G", "B"];
|
| 286 |
+
const activeMatrix = useColor
|
| 287 |
+
? draftColorKernel[selectedChannel]
|
| 288 |
+
: draftGrayKernel;
|
| 289 |
+
const currentHeight = activeMatrix.length;
|
| 290 |
+
const currentWidth = activeMatrix[0]?.length ?? 0;
|
| 291 |
+
const isModified = useColor
|
| 292 |
+
? loadedColorPreset !== null
|
| 293 |
+
&& !kernelsEqual3D(
|
| 294 |
+
colorKernel,
|
| 295 |
+
COLOR_KERNEL_PRESETS[loadedColorPreset as keyof typeof COLOR_KERNEL_PRESETS],
|
| 296 |
+
)
|
| 297 |
+
: loadedGrayPreset !== null
|
| 298 |
+
&& !kernelsEqual2D(
|
| 299 |
+
grayscaleKernel,
|
| 300 |
+
GRAY_KERNEL_PRESETS[loadedGrayPreset as keyof typeof GRAY_KERNEL_PRESETS],
|
| 301 |
+
);
|
| 302 |
+
const loadedPresetLabel = useColor ? loadedColorPreset : loadedGrayPreset;
|
| 303 |
+
|
| 304 |
+
return (
|
| 305 |
+
<div className="w-full max-w-sm border border-gray-300 rounded p-3">
|
| 306 |
+
<h3 className="text-sm font-semibold mb-2">Kernel</h3>
|
| 307 |
+
<div className="flex items-end gap-2 mb-3">
|
| 308 |
+
<Dropdown
|
| 309 |
+
label="Preset"
|
| 310 |
+
options={useColor ? colorPresetNames : grayPresetNames}
|
| 311 |
+
activeOption={useColor ? selectedColorPreset : selectedGrayPreset}
|
| 312 |
+
onChange={(option) => {
|
| 313 |
+
if (useColor) {
|
| 314 |
+
setSelectedColorPreset(option);
|
| 315 |
+
return;
|
| 316 |
+
}
|
| 317 |
+
setSelectedGrayPreset(option);
|
| 318 |
+
}}
|
| 319 |
+
/>
|
| 320 |
+
<Button label="Load preset" onClick={handleLoadPreset} />
|
| 321 |
+
</div>
|
| 322 |
+
{loadedPresetLabel && (
|
| 323 |
+
<div className="text-sm mb-3">
|
| 324 |
+
<span>Loaded: {loadedPresetLabel}</span>
|
| 325 |
+
{isModified && <span className="text-orange-700 ml-2">(modified)</span>}
|
| 326 |
+
</div>
|
| 327 |
+
)}
|
| 328 |
+
{ useColor && (
|
| 329 |
+
<div className="flex gap-2 mb-3">
|
| 330 |
+
{channelLabels.map((label, index) => (
|
| 331 |
+
<button
|
| 332 |
+
key={label}
|
| 333 |
+
type="button"
|
| 334 |
+
className={`px-3 py-1 text-sm rounded border ${
|
| 335 |
+
selectedChannel === index ? "bg-orange-200 border-orange-300" : "bg-white border-gray-300"
|
| 336 |
+
}`}
|
| 337 |
+
onClick={() => setSelectedChannel(index)}
|
| 338 |
+
>
|
| 339 |
+
{label}
|
| 340 |
+
</button>
|
| 341 |
+
))}
|
| 342 |
+
</div>
|
| 343 |
+
)}
|
| 344 |
+
<div className="flex flex-col gap-4 mb-3">
|
| 345 |
+
<div className="flex items-center gap-2">
|
| 346 |
+
<span className="text-sm">Width</span>
|
| 347 |
+
<button
|
| 348 |
+
type="button"
|
| 349 |
+
className="px-2 py-1 text-sm rounded border border-gray-300"
|
| 350 |
+
onClick={() => handleSizeChange(currentWidth - 1, currentHeight)}
|
| 351 |
+
>
|
| 352 |
+
-
|
| 353 |
+
</button>
|
| 354 |
+
<span className="text-sm w-6 text-center">{currentWidth}</span>
|
| 355 |
+
<button
|
| 356 |
+
type="button"
|
| 357 |
+
className="px-2 py-1 text-sm rounded border border-gray-300"
|
| 358 |
+
onClick={() => handleSizeChange(currentWidth + 1, currentHeight)}
|
| 359 |
+
>
|
| 360 |
+
+
|
| 361 |
+
</button>
|
| 362 |
+
</div>
|
| 363 |
+
<div className="flex items-center gap-2">
|
| 364 |
+
<span className="text-sm">Height</span>
|
| 365 |
+
<button
|
| 366 |
+
type="button"
|
| 367 |
+
className="px-2 py-1 text-sm rounded border border-gray-300"
|
| 368 |
+
onClick={() => handleSizeChange(currentWidth, currentHeight - 1)}
|
| 369 |
+
>
|
| 370 |
+
-
|
| 371 |
+
</button>
|
| 372 |
+
<span className="text-sm w-6 text-center">{currentHeight}</span>
|
| 373 |
+
<button
|
| 374 |
+
type="button"
|
| 375 |
+
className="px-2 py-1 text-sm rounded border border-gray-300"
|
| 376 |
+
onClick={() => handleSizeChange(currentWidth, currentHeight + 1)}
|
| 377 |
+
>
|
| 378 |
+
+
|
| 379 |
+
</button>
|
| 380 |
+
</div>
|
| 381 |
+
</div>
|
| 382 |
+
<div className="flex flex-col gap-2">
|
| 383 |
+
{activeMatrix.map((row, rowIndex) => (
|
| 384 |
+
<div key={rowIndex} className="flex gap-2">
|
| 385 |
+
{row.map((cellValue, colIndex) => (
|
| 386 |
+
<input
|
| 387 |
+
key={`${rowIndex}-${colIndex}`}
|
| 388 |
+
type="text"
|
| 389 |
+
className="w-14 px-2 py-1 border border-gray-300 rounded text-sm"
|
| 390 |
+
value={cellValue}
|
| 391 |
+
onChange={(event) =>
|
| 392 |
+
handleKernelChange(
|
| 393 |
+
useColor ? selectedChannel : 0,
|
| 394 |
+
rowIndex,
|
| 395 |
+
colIndex,
|
| 396 |
+
event.target.value,
|
| 397 |
+
)
|
| 398 |
+
}
|
| 399 |
+
/>
|
| 400 |
+
))}
|
| 401 |
+
</div>
|
| 402 |
+
))}
|
| 403 |
+
</div>
|
| 404 |
+
</div>
|
| 405 |
+
);
|
| 406 |
+
}
|
src/InfoViewer.tsx
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useMemo, useState } from "react";
|
| 2 |
+
import type { RunInfo } from "./train.ts";
|
| 3 |
+
import Button from "./ui/Button.tsx";
|
| 4 |
+
|
| 5 |
+
type NumericArray = ArrayLike<number>;
|
| 6 |
+
type Shape4D = [number, number, number, number];
|
| 7 |
+
|
| 8 |
+
interface InputLayer {
|
| 9 |
+
type: "input";
|
| 10 |
+
output: NumericArray;
|
| 11 |
+
shape: Shape4D;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
interface Conv2dLayer {
|
| 15 |
+
type: "conv2d";
|
| 16 |
+
output: NumericArray;
|
| 17 |
+
kernels: NumericArray;
|
| 18 |
+
outputShape: Shape4D;
|
| 19 |
+
kernelShape: [number, number, number, number];
|
| 20 |
+
stride: number;
|
| 21 |
+
padding: number | "same" | "valid";
|
| 22 |
+
activationType?: string;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
interface MaxPoolLayer {
|
| 26 |
+
type: "maxpool";
|
| 27 |
+
output: NumericArray;
|
| 28 |
+
shape: Shape4D;
|
| 29 |
+
size: number;
|
| 30 |
+
stride: number;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
interface FlattenLayer {
|
| 34 |
+
type: "flatten";
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
interface DenseLayer {
|
| 38 |
+
type: "dense";
|
| 39 |
+
details?: string;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
interface OutputLayer {
|
| 43 |
+
type: "output";
|
| 44 |
+
output: NumericArray;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
type LayerInfo =
|
| 48 |
+
| InputLayer
|
| 49 |
+
| Conv2dLayer
|
| 50 |
+
| MaxPoolLayer
|
| 51 |
+
| FlattenLayer
|
| 52 |
+
| DenseLayer
|
| 53 |
+
| OutputLayer;
|
| 54 |
+
|
| 55 |
+
interface InfoViewerProps {
|
| 56 |
+
info?: RunInfo[];
|
| 57 |
+
onSampleIndexChange: () => void;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
interface InputViewerProps {
|
| 61 |
+
output: NumericArray;
|
| 62 |
+
shape: Shape4D;
|
| 63 |
+
onSampleIndexChange: () => void;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
interface Conv2dLayerViewerProps {
|
| 67 |
+
layerIdx: number;
|
| 68 |
+
stride: number;
|
| 69 |
+
padding: number | "same" | "valid";
|
| 70 |
+
activationType?: string;
|
| 71 |
+
kernels: NumericArray;
|
| 72 |
+
output: NumericArray;
|
| 73 |
+
kernelShape: [number, number, number, number];
|
| 74 |
+
outputShape: Shape4D;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
interface MaxPoolLayerViewerProps {
|
| 78 |
+
layerIdx: number;
|
| 79 |
+
stride: number;
|
| 80 |
+
size: number;
|
| 81 |
+
output: NumericArray;
|
| 82 |
+
shape: Shape4D;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
interface OutputLayerViewerProps {
|
| 86 |
+
probs: NumericArray;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
function asLayerInfo(layer: RunInfo): LayerInfo | null {
|
| 90 |
+
if (typeof layer.type !== "string") {
|
| 91 |
+
return null;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
switch (layer.type) {
|
| 95 |
+
case "input":
|
| 96 |
+
case "conv2d":
|
| 97 |
+
case "maxpool":
|
| 98 |
+
case "flatten":
|
| 99 |
+
case "dense":
|
| 100 |
+
case "output":
|
| 101 |
+
return layer as unknown as LayerInfo;
|
| 102 |
+
default:
|
| 103 |
+
return null;
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
function extractImage(output: NumericArray, h: number, w: number, cCount: number): ImageData {
|
| 108 |
+
const buffer = new Uint8ClampedArray(h * w * 4);
|
| 109 |
+
|
| 110 |
+
for (let i = 0; i < h; ++i) {
|
| 111 |
+
for (let j = 0; j < w; ++j) {
|
| 112 |
+
for (let c = 0; c < cCount; ++c) {
|
| 113 |
+
const val = output[i * w * cCount + j * cCount + c];
|
| 114 |
+
buffer[(i * w + j) * 4 + c] = val * 255;
|
| 115 |
+
}
|
| 116 |
+
for (let c = cCount; c < 3; ++c) {
|
| 117 |
+
buffer[(i * w + j) * 4 + c] = buffer[(i * w + j) * 4 + (cCount - 1)];
|
| 118 |
+
}
|
| 119 |
+
buffer[(i * w + j) * 4 + 3] = 255;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
return new ImageData(buffer, w, h);
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
function extractKernels(
|
| 127 |
+
kernels: NumericArray,
|
| 128 |
+
selectedOutputChannel: number,
|
| 129 |
+
kh: number,
|
| 130 |
+
kw: number,
|
| 131 |
+
inC: number,
|
| 132 |
+
outC: number,
|
| 133 |
+
): ImageData[] {
|
| 134 |
+
let minVal = Infinity;
|
| 135 |
+
let maxVal = -Infinity;
|
| 136 |
+
|
| 137 |
+
for (let i = 0; i < kernels.length; ++i) {
|
| 138 |
+
const val = kernels[i];
|
| 139 |
+
if (val < minVal) minVal = val;
|
| 140 |
+
if (val > maxVal) maxVal = val;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
const kernelImgDatas: ImageData[] = [];
|
| 144 |
+
for (let ic = 0; ic < inC; ++ic) {
|
| 145 |
+
const buffer = new Uint8ClampedArray(kh * kw * 4);
|
| 146 |
+
for (let i = 0; i < kh; ++i) {
|
| 147 |
+
for (let j = 0; j < kw; ++j) {
|
| 148 |
+
const val = kernels[i * kw * inC * outC + j * inC * outC + ic * outC + selectedOutputChannel];
|
| 149 |
+
const normVal = (val - minVal) / (maxVal - minVal + 1e-8);
|
| 150 |
+
const pixel = Math.round(normVal * 255);
|
| 151 |
+
buffer[(i * kw + j) * 4 + 0] = pixel;
|
| 152 |
+
buffer[(i * kw + j) * 4 + 1] = pixel;
|
| 153 |
+
buffer[(i * kw + j) * 4 + 2] = pixel;
|
| 154 |
+
buffer[(i * kw + j) * 4 + 3] = 255;
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
kernelImgDatas.push(new ImageData(buffer, kw, kh));
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
return kernelImgDatas;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
function extractActivationMaps(activations: NumericArray, h: number, w: number, cCount: number): ImageData[] {
|
| 164 |
+
let minVal = Infinity;
|
| 165 |
+
let maxVal = -Infinity;
|
| 166 |
+
|
| 167 |
+
for (let i = 0; i < activations.length; ++i) {
|
| 168 |
+
const val = activations[i];
|
| 169 |
+
if (val < minVal) minVal = val;
|
| 170 |
+
if (val > maxVal) maxVal = val;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
const activationImgDatas: ImageData[] = [];
|
| 174 |
+
for (let c = 0; c < cCount; ++c) {
|
| 175 |
+
const buffer = new Uint8ClampedArray(h * w * 4);
|
| 176 |
+
for (let i = 0; i < h; ++i) {
|
| 177 |
+
for (let j = 0; j < w; ++j) {
|
| 178 |
+
const val = activations[i * w * cCount + j * cCount + c];
|
| 179 |
+
const normVal = (val - minVal) / (maxVal - minVal + 1e-8);
|
| 180 |
+
const pixel = Math.round(normVal * 255);
|
| 181 |
+
buffer[(i * w + j) * 4 + 0] = pixel;
|
| 182 |
+
buffer[(i * w + j) * 4 + 1] = pixel;
|
| 183 |
+
buffer[(i * w + j) * 4 + 2] = pixel;
|
| 184 |
+
buffer[(i * w + j) * 4 + 3] = 255;
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
activationImgDatas.push(new ImageData(buffer, w, h));
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
return activationImgDatas;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
function imgDataToSrc(imgData: ImageData): string {
|
| 194 |
+
const canvas = document.createElement("canvas");
|
| 195 |
+
canvas.width = imgData.width;
|
| 196 |
+
canvas.height = imgData.height;
|
| 197 |
+
const ctx = canvas.getContext("2d");
|
| 198 |
+
if (!ctx) return "";
|
| 199 |
+
ctx.putImageData(imgData, 0, 0);
|
| 200 |
+
return canvas.toDataURL();
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
function InputViewer({ output, shape, onSampleIndexChange }: InputViewerProps) {
|
| 204 |
+
const [, h, w, cCount] = shape;
|
| 205 |
+
const imgData = useMemo(() => extractImage(output, h, w, cCount), [output, h, w, cCount]);
|
| 206 |
+
const imgSrc = useMemo(() => imgDataToSrc(imgData), [imgData]);
|
| 207 |
+
|
| 208 |
+
return (
|
| 209 |
+
<section className="rounded-md border border-slate-200 bg-white p-4 shadow-sm">
|
| 210 |
+
<h3 className="text-lg font-semibold text-slate-900">Input Layer</h3>
|
| 211 |
+
<div className="mt-2 text-sm text-slate-700">
|
| 212 |
+
<strong>Input size:</strong> {imgData.width} x {imgData.height}
|
| 213 |
+
</div>
|
| 214 |
+
<h4 className="mt-3 text-sm font-medium text-slate-800">Sample input</h4>
|
| 215 |
+
<div className="mt-2 flex flex-wrap items-center gap-3">
|
| 216 |
+
{imgSrc && (
|
| 217 |
+
<img
|
| 218 |
+
src={imgSrc}
|
| 219 |
+
alt="Input sample"
|
| 220 |
+
className="h-24 w-24 rounded border border-slate-200 object-contain"
|
| 221 |
+
/>
|
| 222 |
+
)}
|
| 223 |
+
<Button label="New Sample" onClick={onSampleIndexChange} />
|
| 224 |
+
</div>
|
| 225 |
+
</section>
|
| 226 |
+
);
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
function Conv2dLayerViewer({
|
| 230 |
+
layerIdx,
|
| 231 |
+
stride,
|
| 232 |
+
padding,
|
| 233 |
+
activationType,
|
| 234 |
+
kernels,
|
| 235 |
+
output,
|
| 236 |
+
kernelShape,
|
| 237 |
+
outputShape,
|
| 238 |
+
}: Conv2dLayerViewerProps) {
|
| 239 |
+
const [selectedChannel, setSelectedChannel] = useState<number | null>(null);
|
| 240 |
+
const [kh, kw, inC, outC] = kernelShape;
|
| 241 |
+
const [, h, w, cCount] = outputShape;
|
| 242 |
+
|
| 243 |
+
const kernelImgDatas = useMemo(() => {
|
| 244 |
+
if (selectedChannel == null) return null;
|
| 245 |
+
return extractKernels(kernels, selectedChannel, kh, kw, inC, outC);
|
| 246 |
+
}, [kernels, selectedChannel, kh, kw, inC, outC]);
|
| 247 |
+
|
| 248 |
+
const kernelSrcs = useMemo(() => kernelImgDatas?.map(imgDataToSrc) ?? null, [kernelImgDatas]);
|
| 249 |
+
|
| 250 |
+
const activations = useMemo(
|
| 251 |
+
() => extractActivationMaps(output, h, w, cCount),
|
| 252 |
+
[output, h, w, cCount],
|
| 253 |
+
);
|
| 254 |
+
const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
|
| 255 |
+
|
| 256 |
+
return (
|
| 257 |
+
<section className="rounded-md border border-slate-200 bg-white p-4 shadow-sm">
|
| 258 |
+
<h3 className="text-lg font-semibold text-slate-900">Convolution Layer</h3>
|
| 259 |
+
<div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
|
| 260 |
+
<div>
|
| 261 |
+
<strong>Kernel Size:</strong> {kh} x {kw}
|
| 262 |
+
</div>
|
| 263 |
+
<div>
|
| 264 |
+
<strong>Stride:</strong> {stride}
|
| 265 |
+
</div>
|
| 266 |
+
<div>
|
| 267 |
+
<strong>Padding:</strong> {padding}
|
| 268 |
+
</div>
|
| 269 |
+
<div>
|
| 270 |
+
<strong>Activation:</strong> {activationType ?? "none"}
|
| 271 |
+
</div>
|
| 272 |
+
<div>
|
| 273 |
+
<strong>Output channels:</strong> {cCount}
|
| 274 |
+
</div>
|
| 275 |
+
</div>
|
| 276 |
+
|
| 277 |
+
{selectedChannel != null && kernelSrcs && (
|
| 278 |
+
<>
|
| 279 |
+
<h4 className="mt-3 text-sm font-medium text-slate-800">
|
| 280 |
+
Kernel for output {selectedChannel} (min-max normalized)
|
| 281 |
+
</h4>
|
| 282 |
+
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
|
| 283 |
+
{kernelSrcs.map((src, idx) => (
|
| 284 |
+
<img
|
| 285 |
+
key={`${layerIdx}-${idx}-kernel`}
|
| 286 |
+
src={src}
|
| 287 |
+
alt={`Kernel ${idx}`}
|
| 288 |
+
className="h-24 w-24 rounded border border-slate-200 object-contain"
|
| 289 |
+
/>
|
| 290 |
+
))}
|
| 291 |
+
</div>
|
| 292 |
+
</>
|
| 293 |
+
)}
|
| 294 |
+
|
| 295 |
+
<h4 className="mt-3 text-sm font-medium text-slate-800">Activation Maps (min-max normalized)</h4>
|
| 296 |
+
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
|
| 297 |
+
{activationSrcs.map((src, idx) => (
|
| 298 |
+
<img
|
| 299 |
+
key={`${layerIdx}-${idx}-activation`}
|
| 300 |
+
src={src}
|
| 301 |
+
alt={`Activation Map ${idx}`}
|
| 302 |
+
onClick={() => setSelectedChannel(selectedChannel === idx ? null : idx)}
|
| 303 |
+
className={`h-24 w-24 cursor-pointer rounded border object-contain ${
|
| 304 |
+
selectedChannel === idx ? "border-lime-500 ring-2 ring-lime-300" : "border-slate-200"
|
| 305 |
+
}`}
|
| 306 |
+
/>
|
| 307 |
+
))}
|
| 308 |
+
</div>
|
| 309 |
+
</section>
|
| 310 |
+
);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
function MaxPoolLayerViewer({ layerIdx, stride, size, output, shape }: MaxPoolLayerViewerProps) {
|
| 314 |
+
const [, h, w, cCount] = shape;
|
| 315 |
+
|
| 316 |
+
const activations = useMemo(() => extractActivationMaps(output, h, w, cCount), [output, h, w, cCount]);
|
| 317 |
+
const activationSrcs = useMemo(() => activations.map(imgDataToSrc), [activations]);
|
| 318 |
+
|
| 319 |
+
return (
|
| 320 |
+
<section className="rounded-md border border-slate-200 bg-white p-4 shadow-sm">
|
| 321 |
+
<h3 className="text-lg font-semibold text-slate-900">MaxPool Layer</h3>
|
| 322 |
+
<div className="mt-2 grid grid-cols-2 gap-2 text-sm text-slate-700">
|
| 323 |
+
<div>
|
| 324 |
+
<strong>Pool Size:</strong> {size} x {size}
|
| 325 |
+
</div>
|
| 326 |
+
<div>
|
| 327 |
+
<strong>Stride:</strong> {stride}
|
| 328 |
+
</div>
|
| 329 |
+
</div>
|
| 330 |
+
<h4 className="mt-3 text-sm font-medium text-slate-800">Outputs (min-max normalized)</h4>
|
| 331 |
+
<div className="mt-2 grid grid-cols-4 gap-2 sm:grid-cols-6 md:grid-cols-8">
|
| 332 |
+
{activationSrcs.map((src, idx) => (
|
| 333 |
+
<img
|
| 334 |
+
key={`${layerIdx}-${idx}-output`}
|
| 335 |
+
src={src}
|
| 336 |
+
alt={`Output ${idx}`}
|
| 337 |
+
className="h-24 w-24 rounded border border-slate-200 object-contain"
|
| 338 |
+
/>
|
| 339 |
+
))}
|
| 340 |
+
</div>
|
| 341 |
+
</section>
|
| 342 |
+
);
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
function OutputLayerViewer({ probs }: OutputLayerViewerProps) {
|
| 346 |
+
const numClasses = probs.length;
|
| 347 |
+
|
| 348 |
+
return (
|
| 349 |
+
<section className="rounded-md border border-slate-200 bg-white p-4 shadow-sm">
|
| 350 |
+
<h3 className="text-lg font-semibold text-slate-900">Output Layer (softmax)</h3>
|
| 351 |
+
<div className="mt-2 text-sm text-slate-700">
|
| 352 |
+
<strong>Number of Classes:</strong> {numClasses}
|
| 353 |
+
</div>
|
| 354 |
+
<h4 className="mt-3 text-sm font-medium text-slate-800">Class Probabilities</h4>
|
| 355 |
+
<div className="mt-2 grid grid-cols-2 gap-2 text-sm sm:grid-cols-3">
|
| 356 |
+
{Array.from({ length: numClasses }).map((_, i) => (
|
| 357 |
+
<div key={i} className="rounded border border-slate-200 bg-slate-50 px-2 py-1">
|
| 358 |
+
<strong>Class {i}:</strong> {Number(probs[i]).toFixed(2)}
|
| 359 |
+
</div>
|
| 360 |
+
))}
|
| 361 |
+
</div>
|
| 362 |
+
</section>
|
| 363 |
+
);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
export default function InfoViewer({ info, onSampleIndexChange }: InfoViewerProps) {
|
| 367 |
+
const layers = useMemo(() => (info ?? []).map(asLayerInfo).filter((v): v is LayerInfo => v !== null), [info]);
|
| 368 |
+
|
| 369 |
+
function renderLayer(layer: LayerInfo, idx: number) {
|
| 370 |
+
switch (layer.type) {
|
| 371 |
+
case "input":
|
| 372 |
+
return (
|
| 373 |
+
<InputViewer
|
| 374 |
+
key={idx}
|
| 375 |
+
output={layer.output}
|
| 376 |
+
shape={layer.shape}
|
| 377 |
+
onSampleIndexChange={onSampleIndexChange}
|
| 378 |
+
/>
|
| 379 |
+
);
|
| 380 |
+
case "conv2d":
|
| 381 |
+
return (
|
| 382 |
+
<Conv2dLayerViewer
|
| 383 |
+
key={idx}
|
| 384 |
+
layerIdx={idx}
|
| 385 |
+
stride={layer.stride}
|
| 386 |
+
padding={layer.padding}
|
| 387 |
+
activationType={layer.activationType}
|
| 388 |
+
kernels={layer.kernels}
|
| 389 |
+
output={layer.output}
|
| 390 |
+
kernelShape={layer.kernelShape}
|
| 391 |
+
outputShape={layer.outputShape}
|
| 392 |
+
/>
|
| 393 |
+
);
|
| 394 |
+
case "maxpool":
|
| 395 |
+
return (
|
| 396 |
+
<MaxPoolLayerViewer
|
| 397 |
+
key={idx}
|
| 398 |
+
layerIdx={idx}
|
| 399 |
+
stride={layer.stride}
|
| 400 |
+
size={layer.size}
|
| 401 |
+
output={layer.output}
|
| 402 |
+
shape={layer.shape}
|
| 403 |
+
/>
|
| 404 |
+
);
|
| 405 |
+
case "flatten":
|
| 406 |
+
return null;
|
| 407 |
+
case "dense":
|
| 408 |
+
return (
|
| 409 |
+
<p key={idx} className="rounded-md border border-slate-200 bg-white p-4 text-sm text-slate-700 shadow-sm">
|
| 410 |
+
Dense Layer: {layer.details ?? "N/A"}
|
| 411 |
+
</p>
|
| 412 |
+
);
|
| 413 |
+
case "output":
|
| 414 |
+
return null;
|
| 415 |
+
default:
|
| 416 |
+
return (
|
| 417 |
+
<p key={idx} className="rounded-md border border-slate-200 bg-white p-4 text-sm text-slate-700 shadow-sm">
|
| 418 |
+
Unknown Layer Type
|
| 419 |
+
</p>
|
| 420 |
+
);
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
const lastLayer = layers.length > 0 ? layers[layers.length - 1] : null;
|
| 425 |
+
const outputLayer = lastLayer?.type === "output" ? lastLayer : null;
|
| 426 |
+
const bodyLayers = outputLayer ? layers.slice(0, -1) : layers;
|
| 427 |
+
|
| 428 |
+
return (
|
| 429 |
+
<div className="space-y-4 p-2">
|
| 430 |
+
{bodyLayers.map((layer, idx) => renderLayer(layer, idx))}
|
| 431 |
+
{outputLayer && <OutputLayerViewer probs={outputLayer.output} />}
|
| 432 |
+
</div>
|
| 433 |
+
);
|
| 434 |
+
}
|
src/NetworkVisualizer.tsx
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState, useRef } from "react";
|
| 2 |
+
import * as tf from "@tensorflow/tfjs";
|
| 3 |
+
import Plot from "react-plotly.js";
|
| 4 |
+
|
| 5 |
+
import { MnistData } from "./mnist.js";
|
| 6 |
+
import { Cnn, train } from "./train.ts";
|
| 7 |
+
import type { TrainController, RunInfo, OptimizerParams } from "./train.ts";
|
| 8 |
+
import Button from "./ui/Button.tsx";
|
| 9 |
+
import InputField from "./ui/InputField.tsx";
|
| 10 |
+
import Tabs from "./ui/Tabs.tsx";
|
| 11 |
+
import Dropdown from "./ui/Dropdown.tsx";
|
| 12 |
+
import InfoViewer from "./InfoViewer.tsx";
|
| 13 |
+
|
| 14 |
+
const DEFAULT_ARCHITECTURE = `[conv2d filters=8 kernel=11
|
| 15 |
+
stride=1 padding=1 activation=relu]
|
| 16 |
+
|
| 17 |
+
[maxpool size=2 stride=2]
|
| 18 |
+
|
| 19 |
+
[flatten]
|
| 20 |
+
|
| 21 |
+
[dense units=10 activation=softmax]`;
|
| 22 |
+
|
| 23 |
+
const isFirefox = navigator.userAgent.toLowerCase().includes("firefox");
|
| 24 |
+
await tf.setBackend(isFirefox ? "cpu" : "webgl");
|
| 25 |
+
await tf.ready();
|
| 26 |
+
|
| 27 |
+
export default function NetworkVisualizer() {
|
| 28 |
+
const [dataset, setDataset] = useState<MnistData | null>(null);
|
| 29 |
+
|
| 30 |
+
useEffect(() => {
|
| 31 |
+
const loadData = async () => {
|
| 32 |
+
const data = new MnistData();
|
| 33 |
+
await data.load();
|
| 34 |
+
setDataset(data);
|
| 35 |
+
console.log("dataset loaded");
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
loadData();
|
| 39 |
+
}, []);
|
| 40 |
+
|
| 41 |
+
// architecture states
|
| 42 |
+
const [architecture, setArchitecture] = useState(DEFAULT_ARCHITECTURE);
|
| 43 |
+
const [optimizerType, setOptimizerType] = useState<string>('adam');
|
| 44 |
+
const [optimizerParams, setOptimizerParams] = useState<OptimizerParams>({
|
| 45 |
+
learningRate: '0.001',
|
| 46 |
+
beta1: '0.9',
|
| 47 |
+
beta2: '0.999',
|
| 48 |
+
epsilon: '1e-8',
|
| 49 |
+
batchSize: '32',
|
| 50 |
+
epochs: '5',
|
| 51 |
+
});
|
| 52 |
+
const modelRef = useRef<Cnn | null>(null);
|
| 53 |
+
const optimizerRef = useRef<tf.Optimizer | null>(null);
|
| 54 |
+
|
| 55 |
+
function handleArchitectureChange(newArchitecture: string) {
|
| 56 |
+
if (isTraining) {
|
| 57 |
+
alert('Cannot change architecture while training is in progress.');
|
| 58 |
+
} else {
|
| 59 |
+
setArchitecture(newArchitecture);
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
function handleOptimizerChange(newOptimizerType: string, newOptimizerParams: OptimizerParams) {
|
| 64 |
+
if (isTraining) {
|
| 65 |
+
alert('Cannot change optimizer settings while training is in progress.');
|
| 66 |
+
} else {
|
| 67 |
+
setOptimizerType(newOptimizerType);
|
| 68 |
+
setOptimizerParams(newOptimizerParams);
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
function handleSampleIndexChange() {
|
| 73 |
+
trainController.current.sampleIndex += 1;
|
| 74 |
+
updateTick();
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
function resetModel() {
|
| 79 |
+
if (!dataset) return;
|
| 80 |
+
|
| 81 |
+
if (modelRef.current) {
|
| 82 |
+
modelRef.current.dispose();
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
const cnn = new Cnn(architecture, dataset.numInputChannels);
|
| 86 |
+
modelRef.current = cnn;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
function resetOptimizer() {
|
| 90 |
+
if (optimizerType === 'adam') {
|
| 91 |
+
const learningRate = parseFloat(optimizerParams.learningRate);
|
| 92 |
+
const beta1 = parseFloat(optimizerParams.beta1 || "0.9");
|
| 93 |
+
const beta2 = parseFloat(optimizerParams.beta2 || "0.999");
|
| 94 |
+
const epsilon = parseFloat(optimizerParams.epsilon || "1e-8");
|
| 95 |
+
|
| 96 |
+
if (Number.isNaN(learningRate) || learningRate <= 0) {
|
| 97 |
+
alert('Invalid learning rate for Adam optimizer.');
|
| 98 |
+
return;
|
| 99 |
+
}
|
| 100 |
+
if (Number.isNaN(beta1) || beta1 < 0) {
|
| 101 |
+
alert('Invalid beta1 for Adam optimizer.');
|
| 102 |
+
return;
|
| 103 |
+
}
|
| 104 |
+
if (Number.isNaN(beta2) || beta2 < 0) {
|
| 105 |
+
alert('Invalid beta2 for Adam optimizer.');
|
| 106 |
+
return;
|
| 107 |
+
}
|
| 108 |
+
if (Number.isNaN(epsilon) || epsilon <= 0) {
|
| 109 |
+
alert('Invalid epsilon for Adam optimizer.');
|
| 110 |
+
return;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
const opt = tf.train.adam(learningRate, beta1, beta2, epsilon);
|
| 114 |
+
if (optimizerRef.current) {
|
| 115 |
+
optimizerRef.current.dispose();
|
| 116 |
+
}
|
| 117 |
+
optimizerRef.current = opt;
|
| 118 |
+
} else if (optimizerType === 'sgd') {
|
| 119 |
+
const learningRate = parseFloat(optimizerParams.learningRate);
|
| 120 |
+
if (Number.isNaN(learningRate) || learningRate <= 0) {
|
| 121 |
+
alert('Invalid learning rate for SGD optimizer.');
|
| 122 |
+
return;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
const opt = tf.train.sgd(learningRate);
|
| 126 |
+
if (optimizerRef.current) {
|
| 127 |
+
optimizerRef.current.dispose();
|
| 128 |
+
}
|
| 129 |
+
optimizerRef.current = opt;
|
| 130 |
+
} else {
|
| 131 |
+
alert(`Unsupported optimizer type: ${optimizerType}`);
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// reset & init model and optimizer
|
| 136 |
+
useEffect(() => {
|
| 137 |
+
resetModel();
|
| 138 |
+
resetOptimizer();
|
| 139 |
+
}, [architecture, optimizerType, optimizerParams, dataset]);
|
| 140 |
+
|
| 141 |
+
// training states
|
| 142 |
+
const [isTraining, setIsTraining] = useState<boolean>(false);
|
| 143 |
+
const lossesRef = useRef<Array<number>>([]);
|
| 144 |
+
const trainController = useRef<TrainController>({
|
| 145 |
+
isPaused: false,
|
| 146 |
+
stopRequested: false,
|
| 147 |
+
sampleIndex: 0,
|
| 148 |
+
});
|
| 149 |
+
const infoRef = useRef<Array<RunInfo>>([]);
|
| 150 |
+
|
| 151 |
+
// render timing
|
| 152 |
+
const [, setTick] = useState<number>(0);
|
| 153 |
+
|
| 154 |
+
function updateTick() {
|
| 155 |
+
setTick((tick) => tick + 1);
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
async function startTraining() {
|
| 159 |
+
if (!modelRef || !dataset || !optimizerRef || isTraining) {
|
| 160 |
+
return;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
setIsTraining(true);
|
| 164 |
+
trainController.current.isPaused = false;
|
| 165 |
+
trainController.current.stopRequested = false;
|
| 166 |
+
|
| 167 |
+
const batchSize = parseFloat(optimizerParams.batchSize);
|
| 168 |
+
if (Number.isNaN(batchSize) || batchSize <= 0) {
|
| 169 |
+
alert('Invalid batch size.');
|
| 170 |
+
setIsTraining(false);
|
| 171 |
+
return;
|
| 172 |
+
}
|
| 173 |
+
const epochs = parseFloat(optimizerParams.epochs);
|
| 174 |
+
if (Number.isNaN(epochs) || epochs <= 0) {
|
| 175 |
+
alert('Invalid number of epochs.');
|
| 176 |
+
setIsTraining(false);
|
| 177 |
+
return;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
let lastTickUpdate = 0;
|
| 181 |
+
if (!modelRef.current) return;
|
| 182 |
+
if (!optimizerRef.current) return;
|
| 183 |
+
|
| 184 |
+
try {
|
| 185 |
+
await train(
|
| 186 |
+
dataset,
|
| 187 |
+
modelRef.current,
|
| 188 |
+
optimizerRef.current,
|
| 189 |
+
batchSize,
|
| 190 |
+
epochs,
|
| 191 |
+
trainController.current,
|
| 192 |
+
(_epoch, _batch, loss, info) => {
|
| 193 |
+
// lossesRef.current.push({ epoch, batch, loss });
|
| 194 |
+
lossesRef.current.push(loss);
|
| 195 |
+
console.log(loss);
|
| 196 |
+
|
| 197 |
+
infoRef.current = info;
|
| 198 |
+
|
| 199 |
+
// update tick every 50ms
|
| 200 |
+
const now = performance.now();
|
| 201 |
+
if (now - lastTickUpdate > 50) {
|
| 202 |
+
lastTickUpdate = now;
|
| 203 |
+
updateTick();
|
| 204 |
+
}
|
| 205 |
+
},
|
| 206 |
+
);
|
| 207 |
+
} finally {
|
| 208 |
+
setIsTraining(false);
|
| 209 |
+
trainController.current.isPaused = false;
|
| 210 |
+
trainController.current.stopRequested = false;
|
| 211 |
+
alert('Training finished.');
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
function handleStartTraining() {
|
| 216 |
+
console.log('Starting training...');
|
| 217 |
+
// trainController updated in startTraining
|
| 218 |
+
startTraining();
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
function handlePauseTraining() {
|
| 222 |
+
console.log('Pausing training...');
|
| 223 |
+
trainController.current.isPaused = true;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
function handleContinueTraining() {
|
| 227 |
+
console.log('Continuing training...');
|
| 228 |
+
trainController.current.isPaused = false;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
function handleStopTraining() {
|
| 232 |
+
console.log('Stopping training...');
|
| 233 |
+
trainController.current.stopRequested = true;
|
| 234 |
+
trainController.current.isPaused = false;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
async function waitUntilNotTraining() {
|
| 238 |
+
return new Promise<void>((resolve) => {
|
| 239 |
+
function check() {
|
| 240 |
+
if (!isTraining) {
|
| 241 |
+
resolve();
|
| 242 |
+
} else {
|
| 243 |
+
requestAnimationFrame(check);
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
check();
|
| 247 |
+
});
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
async function handleResetTraining() {
|
| 251 |
+
console.log('Resetting training...');
|
| 252 |
+
handleStopTraining();
|
| 253 |
+
|
| 254 |
+
await waitUntilNotTraining();
|
| 255 |
+
console.log('Training stopped. Resetting model.');
|
| 256 |
+
lossesRef.current = [];
|
| 257 |
+
infoRef.current = [];
|
| 258 |
+
resetModel();
|
| 259 |
+
resetOptimizer();
|
| 260 |
+
|
| 261 |
+
updateTick();
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
return (
|
| 265 |
+
<div className="grid grid-cols-[2fr_1fr] h-dvh">
|
| 266 |
+
<TrainingViewer
|
| 267 |
+
isTraining={isTraining}
|
| 268 |
+
lossesRef={lossesRef}
|
| 269 |
+
infoRef={infoRef}
|
| 270 |
+
handleSampleIndexChange={handleSampleIndexChange}
|
| 271 |
+
/>
|
| 272 |
+
<Sidebar
|
| 273 |
+
architecture={architecture}
|
| 274 |
+
onArchitectureChange={handleArchitectureChange}
|
| 275 |
+
|
| 276 |
+
optimizerType={optimizerType}
|
| 277 |
+
optimizerParams={optimizerParams}
|
| 278 |
+
onOptimizerChange={handleOptimizerChange}
|
| 279 |
+
|
| 280 |
+
onStartTraining={handleStartTraining}
|
| 281 |
+
onPauseTraining={handlePauseTraining}
|
| 282 |
+
onContinueTraining={handleContinueTraining}
|
| 283 |
+
onStopTraining={handleStopTraining}
|
| 284 |
+
onResetTraining={handleResetTraining}
|
| 285 |
+
/>
|
| 286 |
+
</div>
|
| 287 |
+
);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
interface TrainingViewerProps {
|
| 291 |
+
isTraining: boolean;
|
| 292 |
+
lossesRef: React.RefObject<Array<number>>;
|
| 293 |
+
infoRef: React.RefObject<Array<RunInfo>>;
|
| 294 |
+
handleSampleIndexChange: () => void;
|
| 295 |
+
}
|
| 296 |
+
function TrainingViewer({
|
| 297 |
+
isTraining,
|
| 298 |
+
lossesRef,
|
| 299 |
+
infoRef,
|
| 300 |
+
handleSampleIndexChange,
|
| 301 |
+
}: TrainingViewerProps) {
|
| 302 |
+
return (
|
| 303 |
+
<div>
|
| 304 |
+
<p>Training { isTraining ? "in progress" : "not in progress" }</p>
|
| 305 |
+
<Plot
|
| 306 |
+
data={[
|
| 307 |
+
{
|
| 308 |
+
x: lossesRef.current.map((_, i) => i),
|
| 309 |
+
y: lossesRef.current,
|
| 310 |
+
mode: 'lines',
|
| 311 |
+
type: 'scatter',
|
| 312 |
+
},
|
| 313 |
+
]}
|
| 314 |
+
layout={{
|
| 315 |
+
xaxis: { title: { text: 'Training steps' } },
|
| 316 |
+
yaxis: { title: { text: 'Train loss' } },
|
| 317 |
+
}}
|
| 318 |
+
/>
|
| 319 |
+
<InfoViewer info={infoRef.current} onSampleIndexChange={handleSampleIndexChange} />
|
| 320 |
+
</div>
|
| 321 |
+
)
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
interface SidebarProps {
|
| 325 |
+
architecture: string;
|
| 326 |
+
onArchitectureChange: (newArchitecture: string) => void;
|
| 327 |
+
|
| 328 |
+
optimizerType: string;
|
| 329 |
+
optimizerParams: OptimizerParams;
|
| 330 |
+
onOptimizerChange: (newOptimizerType: string, newOptimizerParams: OptimizerParams) => void;
|
| 331 |
+
|
| 332 |
+
onStartTraining: () => void;
|
| 333 |
+
onPauseTraining: () => void;
|
| 334 |
+
onContinueTraining: () => void;
|
| 335 |
+
onStopTraining: () => void;
|
| 336 |
+
onResetTraining: () => void;
|
| 337 |
+
}
|
| 338 |
+
function Sidebar({
|
| 339 |
+
architecture,
|
| 340 |
+
onArchitectureChange,
|
| 341 |
+
optimizerType,
|
| 342 |
+
optimizerParams,
|
| 343 |
+
onOptimizerChange,
|
| 344 |
+
onStartTraining,
|
| 345 |
+
onPauseTraining,
|
| 346 |
+
onContinueTraining,
|
| 347 |
+
onStopTraining,
|
| 348 |
+
onResetTraining,
|
| 349 |
+
}: SidebarProps) {
|
| 350 |
+
const tabs = ["Architecture", "Train"];
|
| 351 |
+
const [activeTab, setActiveTab] = useState<string>(tabs[0]);
|
| 352 |
+
|
| 353 |
+
const [architectureDraft, setArchitectureDraft] = useState<string>(architecture);
|
| 354 |
+
|
| 355 |
+
return (
|
| 356 |
+
<div className="bg-white flex flex-col h-full border-l border-gray-200 p-4 gap-4">
|
| 357 |
+
<Tabs tabs={tabs} activeTab={activeTab} onChange={setActiveTab} />
|
| 358 |
+
{ isFirefox && (
|
| 359 |
+
<p className="text-red-500">
|
| 360 |
+
Warning: This demo may be quite slow on Firefox.
|
| 361 |
+
</p>
|
| 362 |
+
)}
|
| 363 |
+
|
| 364 |
+
{ activeTab === "Architecture" && (
|
| 365 |
+
<>
|
| 366 |
+
<InputField
|
| 367 |
+
label="Architecture"
|
| 368 |
+
value={architectureDraft}
|
| 369 |
+
onChange={setArchitectureDraft}
|
| 370 |
+
rows={15}
|
| 371 |
+
/>
|
| 372 |
+
<Button
|
| 373 |
+
label="Apply architecture"
|
| 374 |
+
onClick={() => onArchitectureChange(architectureDraft)}
|
| 375 |
+
/>
|
| 376 |
+
</>
|
| 377 |
+
)}
|
| 378 |
+
|
| 379 |
+
{ activeTab === "Train" && (
|
| 380 |
+
<>
|
| 381 |
+
<Dropdown
|
| 382 |
+
label="Optimizer"
|
| 383 |
+
options={["sgd", "adam"]}
|
| 384 |
+
activeOption={optimizerType}
|
| 385 |
+
onChange={(newOptimizerType) => onOptimizerChange(newOptimizerType, optimizerParams)}
|
| 386 |
+
/>
|
| 387 |
+
{ optimizerType === 'sgd' && (
|
| 388 |
+
<InputField
|
| 389 |
+
label="Learning Rate"
|
| 390 |
+
value={optimizerParams.learningRate}
|
| 391 |
+
onChange={(newLearningRate) => onOptimizerChange(optimizerType, {...optimizerParams, learningRate: newLearningRate})}
|
| 392 |
+
/>
|
| 393 |
+
)}
|
| 394 |
+
{ optimizerType === 'adam' && (
|
| 395 |
+
<>
|
| 396 |
+
<InputField
|
| 397 |
+
label="Learning Rate"
|
| 398 |
+
value={optimizerParams.learningRate}
|
| 399 |
+
onChange={(newLearningRate) => onOptimizerChange(optimizerType, {...optimizerParams, learningRate: newLearningRate})}
|
| 400 |
+
/>
|
| 401 |
+
<InputField
|
| 402 |
+
label="Beta 1"
|
| 403 |
+
value={optimizerParams.beta1}
|
| 404 |
+
onChange={(newBeta1) => onOptimizerChange(optimizerType, {...optimizerParams, beta1: newBeta1})}
|
| 405 |
+
/>
|
| 406 |
+
<InputField
|
| 407 |
+
label="Beta 2"
|
| 408 |
+
value={optimizerParams.beta2}
|
| 409 |
+
onChange={(newBeta2) => onOptimizerChange(optimizerType, {...optimizerParams, beta2: newBeta2})}
|
| 410 |
+
/>
|
| 411 |
+
<InputField
|
| 412 |
+
label="Epsilon"
|
| 413 |
+
value={optimizerParams.epsilon}
|
| 414 |
+
onChange={(newEpsilon) => onOptimizerChange(optimizerType, {...optimizerParams, epsilon: newEpsilon})}
|
| 415 |
+
/>
|
| 416 |
+
</>
|
| 417 |
+
)}
|
| 418 |
+
<InputField
|
| 419 |
+
label="Batch Size"
|
| 420 |
+
value={optimizerParams.batchSize}
|
| 421 |
+
onChange={(newBatchSize) => onOptimizerChange(optimizerType, {...optimizerParams, batchSize: newBatchSize})}
|
| 422 |
+
/>
|
| 423 |
+
<InputField
|
| 424 |
+
label="Epochs"
|
| 425 |
+
value={optimizerParams.epochs}
|
| 426 |
+
onChange={(newEpochs) => onOptimizerChange(optimizerType, {...optimizerParams, epochs: newEpochs})}
|
| 427 |
+
/>
|
| 428 |
+
<Button
|
| 429 |
+
label="Start training"
|
| 430 |
+
onClick={onStartTraining}
|
| 431 |
+
/>
|
| 432 |
+
<Button
|
| 433 |
+
label="Pause training"
|
| 434 |
+
onClick={onPauseTraining}
|
| 435 |
+
/>
|
| 436 |
+
<Button
|
| 437 |
+
label="Continue training"
|
| 438 |
+
onClick={onContinueTraining}
|
| 439 |
+
/>
|
| 440 |
+
<Button
|
| 441 |
+
label="Stop training"
|
| 442 |
+
onClick={onStopTraining}
|
| 443 |
+
/>
|
| 444 |
+
<Button
|
| 445 |
+
label="Reset training"
|
| 446 |
+
onClick={onResetTraining}
|
| 447 |
+
/>
|
| 448 |
+
</>
|
| 449 |
+
)}
|
| 450 |
+
</div>
|
| 451 |
+
)
|
| 452 |
+
}
|
src/datasets.ts
ADDED
|
File without changes
|
src/index.css
CHANGED
|
@@ -1,66 +1 @@
|
|
| 1 |
-
|
| 2 |
-
font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
|
| 3 |
-
line-height: 1.5;
|
| 4 |
-
font-weight: 400;
|
| 5 |
-
|
| 6 |
-
color-scheme: light dark;
|
| 7 |
-
color: rgba(255, 255, 255, 0.87);
|
| 8 |
-
background-color: #242424;
|
| 9 |
-
|
| 10 |
-
font-synthesis: none;
|
| 11 |
-
text-rendering: optimizeLegibility;
|
| 12 |
-
-webkit-font-smoothing: antialiased;
|
| 13 |
-
-moz-osx-font-smoothing: grayscale;
|
| 14 |
-
}
|
| 15 |
-
|
| 16 |
-
a {
|
| 17 |
-
font-weight: 500;
|
| 18 |
-
color: #646cff;
|
| 19 |
-
text-decoration: inherit;
|
| 20 |
-
}
|
| 21 |
-
a:hover {
|
| 22 |
-
color: #535bf2;
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
body {
|
| 26 |
-
margin: 0;
|
| 27 |
-
min-width: 320px;
|
| 28 |
-
min-height: 100vh;
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
h1 {
|
| 32 |
-
font-size: 3.2em;
|
| 33 |
-
line-height: 1.1;
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
button {
|
| 37 |
-
border-radius: 8px;
|
| 38 |
-
border: 1px solid transparent;
|
| 39 |
-
padding: 0.6em 1.2em;
|
| 40 |
-
font-size: 1em;
|
| 41 |
-
font-weight: 500;
|
| 42 |
-
font-family: inherit;
|
| 43 |
-
background-color: #1a1a1a;
|
| 44 |
-
cursor: pointer;
|
| 45 |
-
transition: border-color 0.25s;
|
| 46 |
-
}
|
| 47 |
-
button:hover {
|
| 48 |
-
border-color: orange;
|
| 49 |
-
}
|
| 50 |
-
button:focus,
|
| 51 |
-
button:focus-visible {
|
| 52 |
-
outline: 4px auto -webkit-focus-ring-color;
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
@media (prefers-color-scheme: light) {
|
| 56 |
-
:root {
|
| 57 |
-
color: #213547;
|
| 58 |
-
background-color: #ffffff;
|
| 59 |
-
}
|
| 60 |
-
a:hover {
|
| 61 |
-
color: #747bff;
|
| 62 |
-
}
|
| 63 |
-
button {
|
| 64 |
-
background-color: #f9f9f9;
|
| 65 |
-
}
|
| 66 |
-
}
|
|
|
|
| 1 |
+
@import "tailwindcss";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/kernels.ts
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export const COLOR_KERNEL_PRESETS = {
|
| 2 |
+
Laplacian: [
|
| 3 |
+
[
|
| 4 |
+
[-1, -1, -1],
|
| 5 |
+
[-1, 8, -1],
|
| 6 |
+
[-1, -1, -1],
|
| 7 |
+
],
|
| 8 |
+
[
|
| 9 |
+
[-1, -1, -1],
|
| 10 |
+
[-1, 8, -1],
|
| 11 |
+
[-1, -1, -1],
|
| 12 |
+
],
|
| 13 |
+
[
|
| 14 |
+
[-1, -1, -1],
|
| 15 |
+
[-1, 8, -1],
|
| 16 |
+
[-1, -1, -1],
|
| 17 |
+
],
|
| 18 |
+
],
|
| 19 |
+
"Sobel X": [
|
| 20 |
+
[
|
| 21 |
+
[-1, 0, 1],
|
| 22 |
+
[-2, 0, 2],
|
| 23 |
+
[-1, 0, 1],
|
| 24 |
+
],
|
| 25 |
+
[
|
| 26 |
+
[-1, 0, 1],
|
| 27 |
+
[-2, 0, 2],
|
| 28 |
+
[-1, 0, 1],
|
| 29 |
+
],
|
| 30 |
+
[
|
| 31 |
+
[-1, 0, 1],
|
| 32 |
+
[-2, 0, 2],
|
| 33 |
+
[-1, 0, 1],
|
| 34 |
+
],
|
| 35 |
+
],
|
| 36 |
+
"Sobel Y": [
|
| 37 |
+
[
|
| 38 |
+
[-1, -2, -1],
|
| 39 |
+
[0, 0, 0],
|
| 40 |
+
[1, 2, 1],
|
| 41 |
+
],
|
| 42 |
+
[
|
| 43 |
+
[-1, -2, -1],
|
| 44 |
+
[0, 0, 0],
|
| 45 |
+
[1, 2, 1],
|
| 46 |
+
],
|
| 47 |
+
[
|
| 48 |
+
[-1, -2, -1],
|
| 49 |
+
[0, 0, 0],
|
| 50 |
+
[1, 2, 1],
|
| 51 |
+
],
|
| 52 |
+
],
|
| 53 |
+
"Box Blur": [
|
| 54 |
+
[
|
| 55 |
+
[0.04, 0.04, 0.04],
|
| 56 |
+
[0.04, 0.04, 0.04],
|
| 57 |
+
[0.04, 0.04, 0.04],
|
| 58 |
+
],
|
| 59 |
+
[
|
| 60 |
+
[0.04, 0.04, 0.04],
|
| 61 |
+
[0.04, 0.04, 0.04],
|
| 62 |
+
[0.04, 0.04, 0.04],
|
| 63 |
+
],
|
| 64 |
+
[
|
| 65 |
+
[0.04, 0.04, 0.04],
|
| 66 |
+
[0.04, 0.04, 0.04],
|
| 67 |
+
[0.04, 0.04, 0.04],
|
| 68 |
+
],
|
| 69 |
+
],
|
| 70 |
+
"AlexNet Conv1": [
|
| 71 |
+
[
|
| 72 |
+
[0.119, 0.094, 0.095, 0.105, 0.103, 0.067, 0.050, 0.050, 0.056, 0.022, 0.050],
|
| 73 |
+
[0.075, 0.039, 0.053, 0.076, 0.072, 0.073, 0.052, 0.027, 0.026, -0.011, 0.004],
|
| 74 |
+
[0.075, 0.039, 0.055, 0.056, 0.053, 0.050, 0.048, 0.025, 0.044, 0.010, 0.013],
|
| 75 |
+
[0.070, 0.053, 0.063, 0.062, 0.059, 0.039, 0.045, 0.038, 0.046, 0.002, 0.003],
|
| 76 |
+
[0.087, 0.075, 0.072, 0.083, 0.095, 0.065, 0.034, 0.021, 0.022, -0.011, -0.034],
|
| 77 |
+
[0.096, 0.099, 0.101, 0.109, 0.073, 0.036, -0.007, -0.043, -0.038, -0.057, -0.056],
|
| 78 |
+
[0.115, 0.115, 0.107, 0.091, 0.003, -0.090, -0.113, -0.139, -0.125, -0.084, -0.075],
|
| 79 |
+
[0.095, 0.110, 0.082, 0.042, -0.059, -0.159, -0.124, -0.158, -0.164, -0.115, -0.093],
|
| 80 |
+
[0.093, 0.104, 0.068, 0.024, -0.070, -0.184, -0.136, -0.185, -0.203, -0.128, -0.112],
|
| 81 |
+
[0.044, 0.065, 0.036, 0.005, -0.090, -0.194, -0.244, -0.244, -0.202, -0.114, -0.107],
|
| 82 |
+
[0.047, 0.063, 0.025, -0.020, -0.068, -0.117, -0.140, -0.163, -0.118, -0.096, -0.084],
|
| 83 |
+
],
|
| 84 |
+
[
|
| 85 |
+
[-0.073, -0.058, -0.081, -0.082, -0.068, -0.090, -0.071, -0.028, -0.001, -0.025, 0.025],
|
| 86 |
+
[-0.069, -0.068, -0.076, -0.047, -0.041, -0.065, -0.046, -0.032, -0.004, -0.030, 0.010],
|
| 87 |
+
[-0.100, -0.086, -0.105, -0.093, -0.086, -0.100, -0.086, -0.070, -0.027, -0.023, 0.007],
|
| 88 |
+
[-0.095, -0.078, -0.105, -0.111, -0.102, -0.124, -0.121, -0.086, -0.037, -0.037, -0.007],
|
| 89 |
+
[-0.102, -0.069, -0.123, -0.107, -0.099, -0.102, -0.109, -0.088, -0.055, -0.037, 0.000],
|
| 90 |
+
[-0.140, -0.095, -0.125, -0.121, -0.109, -0.086, -0.038, -0.050, -0.006, 0.023, 0.048],
|
| 91 |
+
[-0.189, -0.122, -0.157, -0.106, -0.099, -0.003, 0.137, 0.146, 0.116, 0.116, 0.092],
|
| 92 |
+
[-0.181, -0.134, -0.155, -0.090, 0.002, 0.174, 0.471, 0.449, 0.282, 0.189, 0.108],
|
| 93 |
+
[-0.151, -0.089, -0.097, -0.027, 0.101, 0.298, 0.572, 0.493, 0.309, 0.181, 0.084],
|
| 94 |
+
[-0.143, -0.076, -0.072, -0.012, 0.070, 0.165, 0.277, 0.253, 0.204, 0.164, 0.095],
|
| 95 |
+
[-0.086, -0.040, -0.051, -0.019, 0.035, 0.125, 0.170, 0.176, 0.164, 0.148, 0.102],
|
| 96 |
+
],
|
| 97 |
+
[
|
| 98 |
+
[-0.024, -0.002, -0.028, -0.008, -0.030, -0.060, -0.050, -0.001, 0.040, -0.007, 0.032],
|
| 99 |
+
[0.000, 0.022, 0.009, 0.048, 0.024, 0.001, 0.017, 0.024, 0.019, -0.014, 0.018],
|
| 100 |
+
[0.005, 0.029, 0.000, 0.001, -0.005, -0.035, -0.022, -0.028, 0.012, -0.003, 0.008],
|
| 101 |
+
[0.018, 0.047, 0.007, 0.001, -0.025, -0.043, -0.060, -0.053, -0.017, -0.024, -0.001],
|
| 102 |
+
[0.026, 0.050, -0.007, -0.013, -0.030, -0.044, -0.089, -0.098, -0.055, -0.044, -0.007],
|
| 103 |
+
[-0.020, 0.024, -0.024, -0.038, -0.077, -0.076, -0.061, -0.060, -0.047, -0.011, 0.025],
|
| 104 |
+
[-0.082, -0.014, -0.070, -0.069, -0.114, -0.054, 0.057, 0.069, 0.039, 0.043, 0.049],
|
| 105 |
+
[-0.096, -0.030, -0.110, -0.088, -0.057, 0.077, 0.355, 0.329, 0.143, 0.061, -0.021],
|
| 106 |
+
[-0.063, -0.012, -0.062, -0.029, 0.032, 0.165, 0.393, 0.312, 0.103, -0.009, -0.080],
|
| 107 |
+
[-0.046, 0.003, -0.040, -0.018, -0.028, 0.010, 0.072, 0.038, -0.026, -0.033, -0.076],
|
| 108 |
+
[-0.019, 0.011, -0.040, -0.042, -0.056, -0.034, -0.031, -0.076, -0.069, -0.041, -0.055],
|
| 109 |
+
],
|
| 110 |
+
],
|
| 111 |
+
Red: [
|
| 112 |
+
[
|
| 113 |
+
[0, 0, 0],
|
| 114 |
+
[0, 1, 0],
|
| 115 |
+
[0, 0, 0],
|
| 116 |
+
],
|
| 117 |
+
[
|
| 118 |
+
[0, 0, 0],
|
| 119 |
+
[0, 0, 0],
|
| 120 |
+
[0, 0, 0],
|
| 121 |
+
],
|
| 122 |
+
[
|
| 123 |
+
[0, 0, 0],
|
| 124 |
+
[0, 0, 0],
|
| 125 |
+
[0, 0, 0],
|
| 126 |
+
],
|
| 127 |
+
],
|
| 128 |
+
Green: [
|
| 129 |
+
[
|
| 130 |
+
[0, 0, 0],
|
| 131 |
+
[0, 0, 0],
|
| 132 |
+
[0, 0, 0],
|
| 133 |
+
],
|
| 134 |
+
[
|
| 135 |
+
[0, 0, 0],
|
| 136 |
+
[0, 1, 0],
|
| 137 |
+
[0, 0, 0],
|
| 138 |
+
],
|
| 139 |
+
[
|
| 140 |
+
[0, 0, 0],
|
| 141 |
+
[0, 0, 0],
|
| 142 |
+
[0, 0, 0],
|
| 143 |
+
],
|
| 144 |
+
],
|
| 145 |
+
Blue: [
|
| 146 |
+
[
|
| 147 |
+
[0, 0, 0],
|
| 148 |
+
[0, 0, 0],
|
| 149 |
+
[0, 0, 0],
|
| 150 |
+
],
|
| 151 |
+
[
|
| 152 |
+
[0, 0, 0],
|
| 153 |
+
[0, 0, 0],
|
| 154 |
+
[0, 0, 0],
|
| 155 |
+
],
|
| 156 |
+
[
|
| 157 |
+
[0, 0, 0],
|
| 158 |
+
[0, 1, 0],
|
| 159 |
+
[0, 0, 0],
|
| 160 |
+
],
|
| 161 |
+
]
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
+
export const GRAY_KERNEL_PRESETS = {
|
| 165 |
+
Laplacian: [
|
| 166 |
+
[-1, -1, -1],
|
| 167 |
+
[-1, 8, -1],
|
| 168 |
+
[-1, -1, -1],
|
| 169 |
+
],
|
| 170 |
+
'Sobel X': [
|
| 171 |
+
[-1, 0, 1],
|
| 172 |
+
[-2, 0, 2],
|
| 173 |
+
[-1, 0, 1],
|
| 174 |
+
],
|
| 175 |
+
'Sobel Y': [
|
| 176 |
+
[-1, -2, -1],
|
| 177 |
+
[0, 0, 0],
|
| 178 |
+
[1, 2, 1],
|
| 179 |
+
],
|
| 180 |
+
Sharpen: [
|
| 181 |
+
[0, -1, 0],
|
| 182 |
+
[-1, 5, -1],
|
| 183 |
+
[0, -1, 0],
|
| 184 |
+
],
|
| 185 |
+
'Box Blur': [
|
| 186 |
+
[0.11, 0.11, 0.11],
|
| 187 |
+
[0.11, 0.11, 0.11],
|
| 188 |
+
[0.11, 0.11, 0.11],
|
| 189 |
+
],
|
| 190 |
+
};
|
| 191 |
+
|
| 192 |
+
export const DEFAULT_COLOR_KERNEL = COLOR_KERNEL_PRESETS['Laplacian'];
|
| 193 |
+
export const DEFAULT_GRAY_KERNEL = GRAY_KERNEL_PRESETS['Laplacian'];
|
src/main.tsx
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { StrictMode } from 'react'
|
| 2 |
+
import { createRoot } from 'react-dom/client'
|
| 3 |
+
import './index.css'
|
| 4 |
+
import App from './App.tsx'
|
| 5 |
+
|
| 6 |
+
createRoot(document.getElementById('root')!).render(
|
| 7 |
+
<StrictMode>
|
| 8 |
+
<App />
|
| 9 |
+
</StrictMode>,
|
| 10 |
+
)
|
src/mnist.d.ts
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import * as tf from "@tensorflow/tfjs";
|
| 2 |
+
|
| 3 |
+
export interface BatchData {
|
| 4 |
+
xs: tf.Tensor2D;
|
| 5 |
+
labels: tf.Tensor2D;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
export interface TestSample {
|
| 9 |
+
xs: tf.Tensor2D;
|
| 10 |
+
labels: tf.Tensor2D;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
export class MnistData {
|
| 14 |
+
shuffledTrainIndex: number;
|
| 15 |
+
shuffledTestIndex: number;
|
| 16 |
+
numClasses: number;
|
| 17 |
+
numInputChannels: number;
|
| 18 |
+
trainSize: number;
|
| 19 |
+
testSize: number;
|
| 20 |
+
imageSize: number;
|
| 21 |
+
|
| 22 |
+
datasetImages: Float32Array;
|
| 23 |
+
datasetLabels: Uint8Array;
|
| 24 |
+
trainIndices: Uint32Array;
|
| 25 |
+
testIndices: Uint32Array;
|
| 26 |
+
trainImages: Float32Array;
|
| 27 |
+
testImages: Float32Array;
|
| 28 |
+
trainLabels: Uint8Array;
|
| 29 |
+
testLabels: Uint8Array;
|
| 30 |
+
|
| 31 |
+
load(): Promise<void>;
|
| 32 |
+
nextTrainBatch(batchSize: number): BatchData;
|
| 33 |
+
nextTestBatch(batchSize: number): BatchData;
|
| 34 |
+
nextBatch(
|
| 35 |
+
batchSize: number,
|
| 36 |
+
data: [Float32Array, Uint8Array],
|
| 37 |
+
index: () => number,
|
| 38 |
+
): BatchData;
|
| 39 |
+
getTestSample(index: number): TestSample;
|
| 40 |
+
}
|
src/mnist.js
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @license
|
| 3 |
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
* =============================================================================
|
| 16 |
+
*/
|
| 17 |
+
|
| 18 |
+
import * as tf from '@tensorflow/tfjs';
|
| 19 |
+
|
| 20 |
+
const IMAGE_SIZE = 784;
|
| 21 |
+
const NUM_CLASSES = 10;
|
| 22 |
+
const NUM_DATASET_ELEMENTS = 65000;
|
| 23 |
+
|
| 24 |
+
const TRAIN_TEST_RATIO = 5 / 6;
|
| 25 |
+
|
| 26 |
+
const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
|
| 27 |
+
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
|
| 28 |
+
|
| 29 |
+
const MNIST_IMAGES_SPRITE_PATH =
|
| 30 |
+
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
|
| 31 |
+
const MNIST_LABELS_PATH =
|
| 32 |
+
'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';
|
| 33 |
+
|
| 34 |
+
/**
|
| 35 |
+
* A class that fetches the sprited MNIST dataset and returns shuffled batches.
|
| 36 |
+
*
|
| 37 |
+
* NOTE: This will get much easier. For now, we do data fetching and
|
| 38 |
+
* manipulation manually.
|
| 39 |
+
*/
|
| 40 |
+
export class MnistData {
|
| 41 |
+
constructor() {
|
| 42 |
+
this.shuffledTrainIndex = 0;
|
| 43 |
+
this.shuffledTestIndex = 0;
|
| 44 |
+
|
| 45 |
+
this.numClasses = NUM_CLASSES;
|
| 46 |
+
this.numInputChannels = 1;
|
| 47 |
+
this.trainSize = NUM_TRAIN_ELEMENTS;
|
| 48 |
+
this.testSize = NUM_TEST_ELEMENTS;
|
| 49 |
+
this.imageSize = 28;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
async load() {
|
| 53 |
+
// Make a request for the MNIST sprited image.
|
| 54 |
+
const img = new Image();
|
| 55 |
+
const canvas = document.createElement('canvas');
|
| 56 |
+
const ctx = canvas.getContext('2d');
|
| 57 |
+
const imgRequest = new Promise((resolve, reject) => {
|
| 58 |
+
img.crossOrigin = '';
|
| 59 |
+
img.onload = () => {
|
| 60 |
+
img.width = img.naturalWidth;
|
| 61 |
+
img.height = img.naturalHeight;
|
| 62 |
+
|
| 63 |
+
const datasetBytesBuffer = new ArrayBuffer(
|
| 64 |
+
NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4,
|
| 65 |
+
);
|
| 66 |
+
|
| 67 |
+
const chunkSize = 5000;
|
| 68 |
+
canvas.width = img.width;
|
| 69 |
+
canvas.height = chunkSize;
|
| 70 |
+
|
| 71 |
+
for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
|
| 72 |
+
const datasetBytesView = new Float32Array(
|
| 73 |
+
datasetBytesBuffer,
|
| 74 |
+
i * IMAGE_SIZE * chunkSize * 4,
|
| 75 |
+
IMAGE_SIZE * chunkSize,
|
| 76 |
+
);
|
| 77 |
+
ctx.drawImage(
|
| 78 |
+
img,
|
| 79 |
+
0,
|
| 80 |
+
i * chunkSize,
|
| 81 |
+
img.width,
|
| 82 |
+
chunkSize,
|
| 83 |
+
0,
|
| 84 |
+
0,
|
| 85 |
+
img.width,
|
| 86 |
+
chunkSize,
|
| 87 |
+
);
|
| 88 |
+
|
| 89 |
+
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
|
| 90 |
+
|
| 91 |
+
for (let j = 0; j < imageData.data.length / 4; j++) {
|
| 92 |
+
// All channels hold an equal value since the image is grayscale, so
|
| 93 |
+
// just read the red channel.
|
| 94 |
+
datasetBytesView[j] = imageData.data[j * 4] / 255;
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
this.datasetImages = new Float32Array(datasetBytesBuffer);
|
| 98 |
+
|
| 99 |
+
resolve();
|
| 100 |
+
};
|
| 101 |
+
img.src = MNIST_IMAGES_SPRITE_PATH;
|
| 102 |
+
});
|
| 103 |
+
|
| 104 |
+
const labelsRequest = fetch(MNIST_LABELS_PATH);
|
| 105 |
+
const [imgResponse, labelsResponse] = await Promise.all([
|
| 106 |
+
imgRequest,
|
| 107 |
+
labelsRequest,
|
| 108 |
+
]);
|
| 109 |
+
|
| 110 |
+
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
|
| 111 |
+
|
| 112 |
+
// Create shuffled indices into the train/test set for when we select a
|
| 113 |
+
// random dataset element for training / validation.
|
| 114 |
+
this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
|
| 115 |
+
this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
|
| 116 |
+
|
| 117 |
+
// Slice the the images and labels into train and test sets.
|
| 118 |
+
this.trainImages = this.datasetImages.slice(
|
| 119 |
+
0,
|
| 120 |
+
IMAGE_SIZE * NUM_TRAIN_ELEMENTS,
|
| 121 |
+
);
|
| 122 |
+
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
|
| 123 |
+
this.trainLabels = this.datasetLabels.slice(
|
| 124 |
+
0,
|
| 125 |
+
NUM_CLASSES * NUM_TRAIN_ELEMENTS,
|
| 126 |
+
);
|
| 127 |
+
this.testLabels = this.datasetLabels.slice(
|
| 128 |
+
NUM_CLASSES * NUM_TRAIN_ELEMENTS,
|
| 129 |
+
);
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
nextTrainBatch(batchSize) {
|
| 133 |
+
return this.nextBatch(
|
| 134 |
+
batchSize,
|
| 135 |
+
[this.trainImages, this.trainLabels],
|
| 136 |
+
() => {
|
| 137 |
+
this.shuffledTrainIndex =
|
| 138 |
+
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
|
| 139 |
+
return this.trainIndices[this.shuffledTrainIndex];
|
| 140 |
+
},
|
| 141 |
+
);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
nextTestBatch(batchSize) {
|
| 145 |
+
return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
|
| 146 |
+
this.shuffledTestIndex =
|
| 147 |
+
(this.shuffledTestIndex + 1) % this.testIndices.length;
|
| 148 |
+
return this.testIndices[this.shuffledTestIndex];
|
| 149 |
+
});
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
nextBatch(batchSize, data, index) {
|
| 153 |
+
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
|
| 154 |
+
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
|
| 155 |
+
|
| 156 |
+
for (let i = 0; i < batchSize; i++) {
|
| 157 |
+
const idx = index();
|
| 158 |
+
|
| 159 |
+
const image = data[0].slice(
|
| 160 |
+
idx * IMAGE_SIZE,
|
| 161 |
+
idx * IMAGE_SIZE + IMAGE_SIZE,
|
| 162 |
+
);
|
| 163 |
+
batchImagesArray.set(image, i * IMAGE_SIZE);
|
| 164 |
+
|
| 165 |
+
const label = data[1].slice(
|
| 166 |
+
idx * NUM_CLASSES,
|
| 167 |
+
idx * NUM_CLASSES + NUM_CLASSES,
|
| 168 |
+
);
|
| 169 |
+
batchLabelsArray.set(label, i * NUM_CLASSES);
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
|
| 173 |
+
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
|
| 174 |
+
|
| 175 |
+
return { xs, labels };
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
getTestSample(index) {
|
| 179 |
+
const idx = index % NUM_TEST_ELEMENTS;
|
| 180 |
+
|
| 181 |
+
const image = this.testImages.slice(
|
| 182 |
+
idx * IMAGE_SIZE,
|
| 183 |
+
(idx + 1) * IMAGE_SIZE,
|
| 184 |
+
);
|
| 185 |
+
const label = this.testLabels.slice(
|
| 186 |
+
idx * NUM_CLASSES,
|
| 187 |
+
(idx + 1) * NUM_CLASSES,
|
| 188 |
+
);
|
| 189 |
+
|
| 190 |
+
const xs = tf.tensor2d(image, [1, IMAGE_SIZE]);
|
| 191 |
+
const labels = tf.tensor2d(label, [1, NUM_CLASSES]);
|
| 192 |
+
|
| 193 |
+
return { xs, labels };
|
| 194 |
+
}
|
| 195 |
+
}
|
src/train.ts
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import * as tf from '@tensorflow/tfjs';
|
| 2 |
+
|
| 3 |
+
type LayerValue = string | number | tf.Variable | null | undefined;
|
| 4 |
+
|
| 5 |
+
interface LayerConfig {
|
| 6 |
+
type: string;
|
| 7 |
+
[key: string]: LayerValue;
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
export interface RunInfo {
|
| 11 |
+
[key: string]: unknown;
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
export interface TrainController {
|
| 15 |
+
isPaused: boolean;
|
| 16 |
+
stopRequested: boolean;
|
| 17 |
+
sampleIndex: number;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
interface BatchData {
|
| 21 |
+
xs: tf.Tensor;
|
| 22 |
+
labels: tf.Tensor2D;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
interface TestSample {
|
| 26 |
+
xs: tf.Tensor;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
interface TrainingData {
|
| 30 |
+
trainSize: number;
|
| 31 |
+
imageSize: number;
|
| 32 |
+
numInputChannels: number;
|
| 33 |
+
nextTrainBatch(batchSize: number): BatchData;
|
| 34 |
+
getTestSample(index: number): TestSample;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
export interface OptimizerParams {
|
| 38 |
+
learningRate: string;
|
| 39 |
+
batchSize: string;
|
| 40 |
+
epochs: string;
|
| 41 |
+
|
| 42 |
+
// sgd only
|
| 43 |
+
momentum?: string;
|
| 44 |
+
// adam only
|
| 45 |
+
beta1?: string;
|
| 46 |
+
beta2?: string;
|
| 47 |
+
epsilon?: string;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
type BatchEndCallback = (
|
| 51 |
+
epoch: number,
|
| 52 |
+
batch: number,
|
| 53 |
+
loss: number,
|
| 54 |
+
info: RunInfo[],
|
| 55 |
+
) => void | Promise<void>;
|
| 56 |
+
|
| 57 |
+
function parseValue(raw: string): string | number {
|
| 58 |
+
if (raw.trim() === '') {
|
| 59 |
+
return raw;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
const num = Number(raw);
|
| 63 |
+
if (!Number.isNaN(num)) {
|
| 64 |
+
return num;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
return raw;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
function parseArchitecture(text: string): LayerConfig[] {
|
| 71 |
+
const layers: LayerConfig[] = [];
|
| 72 |
+
const matches = text.match(/\[(.*?)\]/gs);
|
| 73 |
+
if (!matches) return layers;
|
| 74 |
+
|
| 75 |
+
for (const block of matches) {
|
| 76 |
+
const content = block.slice(1, -1).trim();
|
| 77 |
+
if (content.length === 0) continue;
|
| 78 |
+
|
| 79 |
+
const tokens = content.split(/\s+/);
|
| 80 |
+
if (tokens.length === 0) continue;
|
| 81 |
+
|
| 82 |
+
const type = tokens[0];
|
| 83 |
+
const layer: LayerConfig = { type };
|
| 84 |
+
|
| 85 |
+
for (let i = 1; i < tokens.length; ++i) {
|
| 86 |
+
const token = tokens[i];
|
| 87 |
+
const [rawKey, rawValue] = token.split('=', 2);
|
| 88 |
+
|
| 89 |
+
if (!rawKey || rawValue === undefined) continue;
|
| 90 |
+
|
| 91 |
+
const key = rawKey === 'activation' ? 'activationType' : rawKey;
|
| 92 |
+
layer[key] = parseValue(rawValue);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
layers.push(layer);
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
return layers;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
function getNumber(layer: LayerConfig, key: string): number {
|
| 102 |
+
const value = layer[key];
|
| 103 |
+
if (typeof value !== 'number') {
|
| 104 |
+
throw new Error(`Layer "${layer.type}" is missing numeric "${key}"`);
|
| 105 |
+
}
|
| 106 |
+
return value;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
function getVariable(layer: LayerConfig, key: string): tf.Variable {
|
| 110 |
+
const value = layer[key];
|
| 111 |
+
if (!value || typeof value !== 'object' || !('dispose' in value)) {
|
| 112 |
+
throw new Error(`Layer "${layer.type}" is missing tensor "${key}"`);
|
| 113 |
+
}
|
| 114 |
+
return value as tf.Variable;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
function getPadding(layer: LayerConfig): number | 'same' | 'valid' {
|
| 118 |
+
const padding = layer.padding;
|
| 119 |
+
if (padding === undefined) return 'valid';
|
| 120 |
+
if (typeof padding === 'number') return padding;
|
| 121 |
+
if (padding === 'same' || padding === 'valid') return padding;
|
| 122 |
+
throw new Error(`Layer "${layer.type}" has invalid padding "${String(padding)}"`);
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
function getFlatDim(out: tf.Tensor): number {
|
| 126 |
+
const [, h, w, c] = out.shape;
|
| 127 |
+
if (
|
| 128 |
+
typeof h !== 'number' ||
|
| 129 |
+
typeof w !== 'number' ||
|
| 130 |
+
typeof c !== 'number'
|
| 131 |
+
) {
|
| 132 |
+
throw new Error('Cannot flatten tensor with unknown shape');
|
| 133 |
+
}
|
| 134 |
+
return h * w * c;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
export class Cnn {
|
| 138 |
+
architecture: LayerConfig[];
|
| 139 |
+
inChannels: number;
|
| 140 |
+
weights: tf.Variable[];
|
| 141 |
+
|
| 142 |
+
constructor(architecture: string, inChannels: number) {
|
| 143 |
+
this.architecture = parseArchitecture(architecture);
|
| 144 |
+
this.inChannels = inChannels;
|
| 145 |
+
this.weights = this.initWeights();
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
initWeights(): tf.Variable[] {
|
| 149 |
+
const weights: tf.Variable[] = [];
|
| 150 |
+
let inChannels = this.inChannels;
|
| 151 |
+
|
| 152 |
+
for (const layer of this.architecture) {
|
| 153 |
+
if (layer.type === 'conv2d') {
|
| 154 |
+
const kernel = getNumber(layer, 'kernel');
|
| 155 |
+
const filters = getNumber(layer, 'filters');
|
| 156 |
+
const shape: [number, number, number, number] = [
|
| 157 |
+
kernel,
|
| 158 |
+
kernel,
|
| 159 |
+
inChannels,
|
| 160 |
+
filters,
|
| 161 |
+
];
|
| 162 |
+
const layerWeights = tf.variable(
|
| 163 |
+
tf.randomUniform(
|
| 164 |
+
shape,
|
| 165 |
+
-Math.sqrt(1 / (kernel * kernel * inChannels)),
|
| 166 |
+
Math.sqrt(1 / (kernel * kernel * inChannels)),
|
| 167 |
+
),
|
| 168 |
+
);
|
| 169 |
+
|
| 170 |
+
weights.push(layerWeights);
|
| 171 |
+
layer.weights = layerWeights;
|
| 172 |
+
inChannels = filters;
|
| 173 |
+
} else if (layer.type === 'dense') {
|
| 174 |
+
layer.weights = null;
|
| 175 |
+
layer.biases = null;
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
return weights;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
dispose(): void {
|
| 182 |
+
for (const layer of this.architecture) {
|
| 183 |
+
if (layer.type === 'conv2d') {
|
| 184 |
+
getVariable(layer, 'weights').dispose();
|
| 185 |
+
} else if (layer.type === 'dense') {
|
| 186 |
+
const weights = layer.weights;
|
| 187 |
+
const biases = layer.biases;
|
| 188 |
+
if (weights && typeof weights === 'object' && 'dispose' in weights) {
|
| 189 |
+
(weights as tf.Variable).dispose();
|
| 190 |
+
}
|
| 191 |
+
if (biases && typeof biases === 'object' && 'dispose' in biases) {
|
| 192 |
+
(biases as tf.Variable).dispose();
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
forward(x: tf.Tensor4D): tf.Tensor {
|
| 199 |
+
let out: tf.Tensor = x;
|
| 200 |
+
|
| 201 |
+
for (let i = 0; i < this.architecture.length; i += 1) {
|
| 202 |
+
const layer = this.architecture[i];
|
| 203 |
+
|
| 204 |
+
switch (layer.type) {
|
| 205 |
+
case 'conv2d': {
|
| 206 |
+
const layerWeights = getVariable(layer, 'weights');
|
| 207 |
+
const stride = getNumber(layer, 'stride');
|
| 208 |
+
const padding = getPadding(layer);
|
| 209 |
+
out = tf.conv2d(
|
| 210 |
+
out as tf.Tensor4D,
|
| 211 |
+
layerWeights as tf.Tensor4D,
|
| 212 |
+
stride,
|
| 213 |
+
padding,
|
| 214 |
+
);
|
| 215 |
+
if (layer.activationType === 'relu') {
|
| 216 |
+
out = out.relu();
|
| 217 |
+
}
|
| 218 |
+
break;
|
| 219 |
+
}
|
| 220 |
+
case 'maxpool': {
|
| 221 |
+
const size = getNumber(layer, 'size');
|
| 222 |
+
const stride = getNumber(layer, 'stride');
|
| 223 |
+
out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
|
| 224 |
+
break;
|
| 225 |
+
}
|
| 226 |
+
case 'flatten': {
|
| 227 |
+
const flatDim = getFlatDim(out);
|
| 228 |
+
out = out.reshape([-1, flatDim]);
|
| 229 |
+
|
| 230 |
+
const next = this.architecture[i + 1];
|
| 231 |
+
if (next?.type === 'dense' && next.weights === null) {
|
| 232 |
+
const units = getNumber(next, 'units');
|
| 233 |
+
next.weights = tf.variable(
|
| 234 |
+
tf.randomUniform(
|
| 235 |
+
[flatDim, units],
|
| 236 |
+
-Math.sqrt(1 / flatDim),
|
| 237 |
+
Math.sqrt(1 / flatDim),
|
| 238 |
+
),
|
| 239 |
+
);
|
| 240 |
+
next.biases = tf.variable(tf.zeros([units]));
|
| 241 |
+
}
|
| 242 |
+
break;
|
| 243 |
+
}
|
| 244 |
+
case 'dense': {
|
| 245 |
+
const denseWeights = getVariable(layer, 'weights');
|
| 246 |
+
const denseBiases = getVariable(layer, 'biases');
|
| 247 |
+
out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
|
| 248 |
+
denseBiases as tf.Tensor1D,
|
| 249 |
+
);
|
| 250 |
+
break;
|
| 251 |
+
}
|
| 252 |
+
default:
|
| 253 |
+
break;
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
return out;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
forwardWithInfo(x: tf.Tensor4D): { output: tf.Tensor; info: RunInfo[] } {
|
| 260 |
+
let out: tf.Tensor = x;
|
| 261 |
+
const info: RunInfo[] = [];
|
| 262 |
+
|
| 263 |
+
info.push({
|
| 264 |
+
type: 'input',
|
| 265 |
+
output: out.dataSync(),
|
| 266 |
+
shape: x.shape,
|
| 267 |
+
});
|
| 268 |
+
|
| 269 |
+
for (let i = 0; i < this.architecture.length; i += 1) {
|
| 270 |
+
const layer = this.architecture[i];
|
| 271 |
+
|
| 272 |
+
switch (layer.type) {
|
| 273 |
+
case 'conv2d': {
|
| 274 |
+
const layerWeights = getVariable(layer, 'weights');
|
| 275 |
+
const stride = getNumber(layer, 'stride');
|
| 276 |
+
const padding = getPadding(layer);
|
| 277 |
+
out = tf.conv2d(
|
| 278 |
+
out as tf.Tensor4D,
|
| 279 |
+
layerWeights as tf.Tensor4D,
|
| 280 |
+
stride,
|
| 281 |
+
padding,
|
| 282 |
+
);
|
| 283 |
+
if (layer.activationType === 'relu') {
|
| 284 |
+
out = out.relu();
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
info.push({
|
| 288 |
+
type: 'conv2d',
|
| 289 |
+
output: out.dataSync(),
|
| 290 |
+
kernels: layerWeights.dataSync(),
|
| 291 |
+
outputShape: out.shape,
|
| 292 |
+
kernelShape: layerWeights.shape,
|
| 293 |
+
stride,
|
| 294 |
+
padding,
|
| 295 |
+
activationType: layer.activationType,
|
| 296 |
+
});
|
| 297 |
+
break;
|
| 298 |
+
}
|
| 299 |
+
case 'maxpool': {
|
| 300 |
+
const size = getNumber(layer, 'size');
|
| 301 |
+
const stride = getNumber(layer, 'stride');
|
| 302 |
+
out = tf.maxPool(out as tf.Tensor4D, [size, size], [stride, stride], 0);
|
| 303 |
+
|
| 304 |
+
info.push({
|
| 305 |
+
type: 'maxpool',
|
| 306 |
+
output: out.dataSync(),
|
| 307 |
+
shape: out.shape,
|
| 308 |
+
size,
|
| 309 |
+
stride,
|
| 310 |
+
});
|
| 311 |
+
break;
|
| 312 |
+
}
|
| 313 |
+
case 'flatten': {
|
| 314 |
+
const flatDim = getFlatDim(out);
|
| 315 |
+
out = out.reshape([-1, flatDim]);
|
| 316 |
+
|
| 317 |
+
info.push({
|
| 318 |
+
type: 'flatten',
|
| 319 |
+
output: out.dataSync(),
|
| 320 |
+
shape: out.shape,
|
| 321 |
+
});
|
| 322 |
+
|
| 323 |
+
const next = this.architecture[i + 1];
|
| 324 |
+
if (next?.type === 'dense' && next.weights === null) {
|
| 325 |
+
const units = getNumber(next, 'units');
|
| 326 |
+
next.weights = tf.variable(
|
| 327 |
+
tf.randomUniform(
|
| 328 |
+
[flatDim, units],
|
| 329 |
+
-Math.sqrt(1 / flatDim),
|
| 330 |
+
Math.sqrt(1 / flatDim),
|
| 331 |
+
),
|
| 332 |
+
);
|
| 333 |
+
next.biases = tf.variable(tf.zeros([units]));
|
| 334 |
+
}
|
| 335 |
+
break;
|
| 336 |
+
}
|
| 337 |
+
case 'dense': {
|
| 338 |
+
const denseWeights = getVariable(layer, 'weights');
|
| 339 |
+
const denseBiases = getVariable(layer, 'biases');
|
| 340 |
+
out = tf.matMul(out as tf.Tensor2D, denseWeights as tf.Tensor2D).add(
|
| 341 |
+
denseBiases as tf.Tensor1D,
|
| 342 |
+
);
|
| 343 |
+
|
| 344 |
+
info.push({
|
| 345 |
+
type: 'dense',
|
| 346 |
+
output: out.dataSync(),
|
| 347 |
+
weights: denseWeights.dataSync(),
|
| 348 |
+
biases: denseBiases.dataSync(),
|
| 349 |
+
outputShape: out.shape,
|
| 350 |
+
weightShape: denseWeights.shape,
|
| 351 |
+
biasShape: denseBiases.shape,
|
| 352 |
+
units: getNumber(layer, 'units'),
|
| 353 |
+
});
|
| 354 |
+
break;
|
| 355 |
+
}
|
| 356 |
+
default:
|
| 357 |
+
break;
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
return { output: out, info };
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
export async function train(
|
| 366 |
+
data: TrainingData,
|
| 367 |
+
model: Cnn,
|
| 368 |
+
optimizer: tf.Optimizer,
|
| 369 |
+
batchSize: number,
|
| 370 |
+
epochs: number,
|
| 371 |
+
controller: TrainController,
|
| 372 |
+
onBatchEnd: BatchEndCallback | null = null,
|
| 373 |
+
): Promise<void> {
|
| 374 |
+
const numBatches = Math.floor(data.trainSize / batchSize);
|
| 375 |
+
for (let epoch = 0; epoch < epochs; ++epoch) {
|
| 376 |
+
for (let b = 0; b < numBatches; ++b) {
|
| 377 |
+
if (controller.stopRequested) {
|
| 378 |
+
console.log('Training stopped');
|
| 379 |
+
return;
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
while (controller.isPaused) {
|
| 383 |
+
await tf.nextFrame();
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
const cost = optimizer.minimize(() => {
|
| 387 |
+
const batch = data.nextTrainBatch(batchSize);
|
| 388 |
+
const xs = batch.xs.reshape([
|
| 389 |
+
batchSize,
|
| 390 |
+
data.imageSize,
|
| 391 |
+
data.imageSize,
|
| 392 |
+
data.numInputChannels,
|
| 393 |
+
]) as tf.Tensor4D;
|
| 394 |
+
const preds = model.forward(xs);
|
| 395 |
+
return tf.losses.softmaxCrossEntropy(batch.labels, preds).mean();
|
| 396 |
+
}, true);
|
| 397 |
+
|
| 398 |
+
if (!cost) {
|
| 399 |
+
throw new Error('Optimizer did not return a loss tensor');
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
const lossVal = (await cost.data())[0];
|
| 403 |
+
cost.dispose();
|
| 404 |
+
|
| 405 |
+
const sample = data.getTestSample(controller.sampleIndex);
|
| 406 |
+
const { output, info } = model.forwardWithInfo(
|
| 407 |
+
sample.xs.reshape([
|
| 408 |
+
1,
|
| 409 |
+
data.imageSize,
|
| 410 |
+
data.imageSize,
|
| 411 |
+
data.numInputChannels,
|
| 412 |
+
]) as tf.Tensor4D,
|
| 413 |
+
);
|
| 414 |
+
|
| 415 |
+
const probs = tf.tidy(() => tf.softmax(output));
|
| 416 |
+
info.push({
|
| 417 |
+
type: 'output',
|
| 418 |
+
output: probs.dataSync(),
|
| 419 |
+
shape: probs.shape,
|
| 420 |
+
});
|
| 421 |
+
|
| 422 |
+
if (controller.stopRequested) {
|
| 423 |
+
console.log('Training stopped');
|
| 424 |
+
probs.dispose();
|
| 425 |
+
return;
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
if (onBatchEnd) {
|
| 429 |
+
await onBatchEnd(epoch, b, lossVal, info);
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
probs.dispose();
|
| 433 |
+
await tf.nextFrame();
|
| 434 |
+
}
|
| 435 |
+
}
|
| 436 |
+
console.log('Training complete');
|
| 437 |
+
}
|
src/types.d.ts
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
declare module "react-plotly.js";
|
src/ui/Button.tsx
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface ButtonProps {
|
| 2 |
+
label: string;
|
| 3 |
+
onClick?: () => void;
|
| 4 |
+
}
|
| 5 |
+
|
| 6 |
+
export default function Button({ label, onClick }: ButtonProps) {
|
| 7 |
+
return (
|
| 8 |
+
<button
|
| 9 |
+
onClick={onClick}
|
| 10 |
+
className="px-5 py-2 cursor-pointer bg-orange-200 rounded hover:bg-orange-300 border border-gray-300"
|
| 11 |
+
>
|
| 12 |
+
{label}
|
| 13 |
+
</button>
|
| 14 |
+
);
|
| 15 |
+
}
|
src/ui/Dropdown.tsx
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface DropdownProps<T extends string> {
|
| 2 |
+
label: string;
|
| 3 |
+
options: readonly T[];
|
| 4 |
+
activeOption: T;
|
| 5 |
+
onChange: (option: T) => void;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
export default function Dropdown<T extends string>({ label, options, activeOption, onChange }: DropdownProps<T>) {
|
| 9 |
+
return (
|
| 10 |
+
<div className="flex flex-col gap-1">
|
| 11 |
+
<label className="text-gray-700 text-sm">{label}</label>
|
| 12 |
+
<select
|
| 13 |
+
value={activeOption}
|
| 14 |
+
onChange={(e) => onChange(e.target.value as T)}
|
| 15 |
+
className="p-2 rounded bg-white border border-gray-300"
|
| 16 |
+
>
|
| 17 |
+
{options.map((option) => (
|
| 18 |
+
<option key={option} value={option}>
|
| 19 |
+
{option}
|
| 20 |
+
</option>
|
| 21 |
+
))}
|
| 22 |
+
</select>
|
| 23 |
+
</div>
|
| 24 |
+
);
|
| 25 |
+
}
|
src/ui/InputField.tsx
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface InputFieldProps {
|
| 2 |
+
label: string;
|
| 3 |
+
value?: string;
|
| 4 |
+
onChange?: (value: string) => void;
|
| 5 |
+
readonly?: boolean;
|
| 6 |
+
rows?: number;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
export default function InputField({
|
| 10 |
+
label,
|
| 11 |
+
value,
|
| 12 |
+
onChange,
|
| 13 |
+
readonly = false,
|
| 14 |
+
rows,
|
| 15 |
+
}: InputFieldProps) {
|
| 16 |
+
const commonClasses =
|
| 17 |
+
"p-2 rounded border border-gray-300 w-full";
|
| 18 |
+
|
| 19 |
+
const bgClass = readonly ? "bg-gray-100" : "bg-white";
|
| 20 |
+
|
| 21 |
+
return (
|
| 22 |
+
<div className="flex flex-col gap-1">
|
| 23 |
+
<label className="text-gray-700 text-sm">
|
| 24 |
+
{label}
|
| 25 |
+
</label>
|
| 26 |
+
|
| 27 |
+
{rows && rows > 1 ? (
|
| 28 |
+
<textarea
|
| 29 |
+
value={value}
|
| 30 |
+
rows={rows}
|
| 31 |
+
readOnly={readonly}
|
| 32 |
+
onChange={(e) => onChange && onChange(e.target.value)}
|
| 33 |
+
className={`${commonClasses} ${bgClass} resize-none`}
|
| 34 |
+
/>
|
| 35 |
+
) : (
|
| 36 |
+
<input
|
| 37 |
+
type="text"
|
| 38 |
+
value={value}
|
| 39 |
+
readOnly={readonly}
|
| 40 |
+
onChange={(e) => onChange && onChange(e.target.value)}
|
| 41 |
+
className={`${commonClasses} ${bgClass}`}
|
| 42 |
+
/>
|
| 43 |
+
)}
|
| 44 |
+
</div>
|
| 45 |
+
);
|
| 46 |
+
}
|
src/ui/LoadingScreen.tsx
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export default function LoadingScreen({ message }: { message: string }) {
|
| 2 |
+
return (
|
| 3 |
+
<div className="fixed inset-0 flex flex-col items-center justify-center bg-slate-50 z-50">
|
| 4 |
+
<div className="flex flex-col items-center space-y-6">
|
| 5 |
+
{/* Animated loading spinner */}
|
| 6 |
+
<div className="relative flex items-center justify-center">
|
| 7 |
+
<div className="w-16 h-16 border-4 border-slate-200 border-t-blue-600 rounded-full animate-spin"></div>
|
| 8 |
+
</div>
|
| 9 |
+
|
| 10 |
+
{/* Loading message */}
|
| 11 |
+
<div className="text-center">
|
| 12 |
+
<h2 className="text-xl font-semibold text-slate-800 tracking-tight">
|
| 13 |
+
{message}...
|
| 14 |
+
</h2>
|
| 15 |
+
</div>
|
| 16 |
+
</div>
|
| 17 |
+
</div>
|
| 18 |
+
);
|
| 19 |
+
}
|
src/ui/Radio.tsx
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface RadioProps<T extends string> {
|
| 2 |
+
label?: string;
|
| 3 |
+
options: readonly T[];
|
| 4 |
+
activeOption: T;
|
| 5 |
+
onChange: (option: T) => void;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
export default function Radio<T extends string>({ label, options, activeOption, onChange }: RadioProps<T>) {
|
| 9 |
+
return (
|
| 10 |
+
<div className="flex flex-col gap-1">
|
| 11 |
+
{label && <label className="text-gray-700 text-sm">{label}</label>}
|
| 12 |
+
<div className="flex gap-4">
|
| 13 |
+
{options.map((option) => (
|
| 14 |
+
<label key={option} className="flex items-center gap-1">
|
| 15 |
+
<input
|
| 16 |
+
type="radio"
|
| 17 |
+
value={option}
|
| 18 |
+
checked={activeOption === option}
|
| 19 |
+
onChange={() => onChange(option)}
|
| 20 |
+
/>
|
| 21 |
+
{option}
|
| 22 |
+
</label>
|
| 23 |
+
))}
|
| 24 |
+
</div>
|
| 25 |
+
</div>
|
| 26 |
+
);
|
| 27 |
+
}
|
src/ui/Tabs.tsx
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interface TabsProps<T extends string> {
|
| 2 |
+
tabs: readonly T[];
|
| 3 |
+
activeTab: T;
|
| 4 |
+
onChange: (tab: T) => void;
|
| 5 |
+
}
|
| 6 |
+
|
| 7 |
+
export default function Tabs<T extends string>({ tabs, activeTab, onChange }: TabsProps<T>) {
|
| 8 |
+
return (
|
| 9 |
+
<div className="flex mb-4">
|
| 10 |
+
{tabs.map((tab) => (
|
| 11 |
+
<button
|
| 12 |
+
key={tab}
|
| 13 |
+
onClick={() => onChange(tab)}
|
| 14 |
+
className={`px-5 py-2 cursor-pointer ${
|
| 15 |
+
activeTab === tab
|
| 16 |
+
? "text-orange-500 border-b-2 border-orange-500"
|
| 17 |
+
: "text-gray-950 hover:bg-gray-200"
|
| 18 |
+
}`}
|
| 19 |
+
>
|
| 20 |
+
{tab}
|
| 21 |
+
</button>
|
| 22 |
+
))}
|
| 23 |
+
</div>
|
| 24 |
+
);
|
| 25 |
+
}
|
src/useConvolutionProcessing.ts
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useMemo, useState } from "react";
|
| 2 |
+
|
| 3 |
+
async function getImageData(imageUrl: string): Promise<ImageData> {
|
| 4 |
+
const image = await new Promise<HTMLImageElement>((resolve, reject) => {
|
| 5 |
+
const img = new Image();
|
| 6 |
+
img.crossOrigin = "anonymous";
|
| 7 |
+
img.onload = () => resolve(img);
|
| 8 |
+
img.onerror = reject;
|
| 9 |
+
img.src = imageUrl;
|
| 10 |
+
})
|
| 11 |
+
|
| 12 |
+
const canvas = document.createElement("canvas")
|
| 13 |
+
canvas.width = image.width;
|
| 14 |
+
canvas.height = image.height;
|
| 15 |
+
|
| 16 |
+
const ctx = canvas.getContext("2d");
|
| 17 |
+
if (!ctx) {
|
| 18 |
+
throw new Error("Failed to get canvas context");
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
ctx.drawImage(image, 0, 0);
|
| 22 |
+
return ctx.getImageData(0, 0, canvas.width, canvas.height);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
function getImageUrl(imageData: ImageData): string {
|
| 27 |
+
const canvas = document.createElement("canvas");
|
| 28 |
+
canvas.width = imageData.width;
|
| 29 |
+
canvas.height = imageData.height;
|
| 30 |
+
|
| 31 |
+
const ctx = canvas.getContext("2d")
|
| 32 |
+
if (!ctx) {
|
| 33 |
+
throw new Error("Failed to get canvas context");
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
ctx.putImageData(imageData, 0, 0);
|
| 37 |
+
return canvas.toDataURL("image/png");
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
function convertToGrayscale(imageData: ImageData): ImageData {
|
| 41 |
+
const output = new ImageData(
|
| 42 |
+
new Uint8ClampedArray(imageData.data),
|
| 43 |
+
imageData.width,
|
| 44 |
+
imageData.height,
|
| 45 |
+
);
|
| 46 |
+
const data = output.data;
|
| 47 |
+
for (let i = 0; i < data.length; i += 4) {
|
| 48 |
+
const r = data[i];
|
| 49 |
+
const g = data[i + 1];
|
| 50 |
+
const b = data[i + 2];
|
| 51 |
+
const gray = 0.299 * r + 0.587 * g + 0.114 * b;
|
| 52 |
+
data[i] = data[i + 1] = data[i + 2] = gray;
|
| 53 |
+
}
|
| 54 |
+
return output;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
function convolve(imageData: ImageData, kernel: number[][] | number[][][]): ImageData {
|
| 58 |
+
if (Array.isArray(kernel[0][0])) {
|
| 59 |
+
// 3D kernel (color)
|
| 60 |
+
return convolveColor(imageData, kernel as number[][][]);
|
| 61 |
+
} else {
|
| 62 |
+
// 2D kernel (grayscale)
|
| 63 |
+
return convolveGray(imageData, kernel as number[][]);
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
function convolveGray(image: ImageData, kernel: number[][]): ImageData {
|
| 68 |
+
const kernelWidth = kernel[0].length;
|
| 69 |
+
const kernelHeight = kernel.length;
|
| 70 |
+
|
| 71 |
+
const width = image.width;
|
| 72 |
+
const height = image.height;
|
| 73 |
+
const inputData = image.data;
|
| 74 |
+
|
| 75 |
+
const outputWidth = width - kernelWidth + 1;
|
| 76 |
+
const outputHeight = height - kernelHeight + 1;
|
| 77 |
+
const outputData = new Uint8ClampedArray(outputWidth * outputHeight * 4);
|
| 78 |
+
|
| 79 |
+
for (let y = 0; y < outputHeight; ++y) {
|
| 80 |
+
for (let x = 0; x < outputWidth; ++x) {
|
| 81 |
+
// dot product
|
| 82 |
+
let sum = 0;
|
| 83 |
+
for (let ky = 0; ky < kernelHeight; ++ky) {
|
| 84 |
+
for (let kx = 0; kx < kernelWidth; ++kx) {
|
| 85 |
+
const pixelIndex = ((y + ky) * width + (x + kx)) * 4;
|
| 86 |
+
const pixelValue = inputData[pixelIndex];
|
| 87 |
+
const kernelValue = kernel[ky][kx];
|
| 88 |
+
sum += pixelValue * kernelValue;
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
const outputIndex = (y * outputWidth + x) * 4;
|
| 93 |
+
const clampedValue = Math.min(Math.max(sum, 0), 255);
|
| 94 |
+
outputData[outputIndex] = clampedValue; // R
|
| 95 |
+
outputData[outputIndex + 1] = clampedValue; // G
|
| 96 |
+
outputData[outputIndex + 2] = clampedValue; // B
|
| 97 |
+
outputData[outputIndex + 3] = 255; // A
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
return new ImageData(outputData, outputWidth, outputHeight);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
function convolveColor(image: ImageData, kernel: number[][][]): ImageData {
|
| 105 |
+
const kernelWidth = kernel[0][0].length;
|
| 106 |
+
const kernelHeight = kernel[0].length;
|
| 107 |
+
|
| 108 |
+
const width = image.width;
|
| 109 |
+
const height = image.height;
|
| 110 |
+
const inputData = image.data;
|
| 111 |
+
|
| 112 |
+
const outputWidth = width - kernelWidth + 1;
|
| 113 |
+
const outputHeight = height - kernelHeight + 1;
|
| 114 |
+
const outputData = new Uint8ClampedArray(outputWidth * outputHeight * 4);
|
| 115 |
+
|
| 116 |
+
for (let y = 0; y < outputHeight; ++y) {
|
| 117 |
+
for (let x = 0; x < outputWidth; ++x) {
|
| 118 |
+
// dot product over 3 channels
|
| 119 |
+
let sum = 0;
|
| 120 |
+
for (let ky = 0; ky < kernelHeight; ++ky) {
|
| 121 |
+
for (let kx = 0; kx < kernelWidth; ++kx) {
|
| 122 |
+
const pixelIndex = ((y + ky) * width + (x + kx)) * 4;
|
| 123 |
+
const r = inputData[pixelIndex];
|
| 124 |
+
const g = inputData[pixelIndex + 1];
|
| 125 |
+
const b = inputData[pixelIndex + 2];
|
| 126 |
+
|
| 127 |
+
const kernelR = kernel[0][ky][kx];
|
| 128 |
+
const kernelG = kernel[1][ky][kx];
|
| 129 |
+
const kernelB = kernel[2][ky][kx];
|
| 130 |
+
sum += r * kernelR + g * kernelG + b * kernelB;
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
const outputIndex = (y * outputWidth + x) * 4;
|
| 135 |
+
const clampedValue = Math.min(Math.max(sum, 0), 255);
|
| 136 |
+
outputData[outputIndex] = clampedValue; // R
|
| 137 |
+
outputData[outputIndex + 1] = clampedValue; // G
|
| 138 |
+
outputData[outputIndex + 2] = clampedValue; // B
|
| 139 |
+
outputData[outputIndex + 3] = 255; // A
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
return new ImageData(outputData, outputWidth, outputHeight);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
export default function useConvolutionProcessing(
|
| 147 |
+
rawInputImage: string,
|
| 148 |
+
kernel: number[][][] | number[][],
|
| 149 |
+
): [string | null, string | null] {
|
| 150 |
+
const useColor = Array.isArray(kernel[0][0]); // true if 3D kernel, false if 2D kernel
|
| 151 |
+
|
| 152 |
+
const [rawImageData, setRawImageData] = useState<ImageData | null>(null);
|
| 153 |
+
|
| 154 |
+
// extract input image data (array)
|
| 155 |
+
useEffect(() => {
|
| 156 |
+
let cancelled = false;
|
| 157 |
+
|
| 158 |
+
async function processImage() {
|
| 159 |
+
const imageData = await getImageData(rawInputImage);
|
| 160 |
+
if (!cancelled) {
|
| 161 |
+
setRawImageData(imageData);
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
processImage();
|
| 166 |
+
|
| 167 |
+
return () => {
|
| 168 |
+
cancelled = true;
|
| 169 |
+
}
|
| 170 |
+
}, [rawInputImage]);
|
| 171 |
+
|
| 172 |
+
const processedImageData = useMemo(() => {
|
| 173 |
+
if (!rawImageData) return null;
|
| 174 |
+
|
| 175 |
+
return useColor ? rawImageData : convertToGrayscale(rawImageData);
|
| 176 |
+
}, [rawImageData, useColor]);
|
| 177 |
+
|
| 178 |
+
const outputImageData = useMemo(() => {
|
| 179 |
+
if (!processedImageData) return null;
|
| 180 |
+
|
| 181 |
+
return convolve(processedImageData, kernel);
|
| 182 |
+
}, [processedImageData, kernel]);
|
| 183 |
+
|
| 184 |
+
const inputImage = useMemo(() => {
|
| 185 |
+
if (!processedImageData) return null;
|
| 186 |
+
|
| 187 |
+
return getImageUrl(processedImageData);
|
| 188 |
+
}, [processedImageData]);
|
| 189 |
+
|
| 190 |
+
const outputImage = useMemo(() => {
|
| 191 |
+
if (!outputImageData) return null;
|
| 192 |
+
|
| 193 |
+
return getImageUrl(outputImageData);
|
| 194 |
+
}, [outputImageData]);
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
return [inputImage, outputImage];
|
| 198 |
+
}
|
tsconfig.app.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
| 4 |
+
"target": "ES2022",
|
| 5 |
+
"useDefineForClassFields": true,
|
| 6 |
+
"lib": ["ES2022", "DOM", "DOM.Iterable"],
|
| 7 |
+
"module": "ESNext",
|
| 8 |
+
"types": ["vite/client"],
|
| 9 |
+
"skipLibCheck": true,
|
| 10 |
+
|
| 11 |
+
/* Bundler mode */
|
| 12 |
+
"moduleResolution": "bundler",
|
| 13 |
+
"allowImportingTsExtensions": true,
|
| 14 |
+
"verbatimModuleSyntax": true,
|
| 15 |
+
"moduleDetection": "force",
|
| 16 |
+
"noEmit": true,
|
| 17 |
+
"jsx": "react-jsx",
|
| 18 |
+
|
| 19 |
+
/* Linting */
|
| 20 |
+
"strict": true,
|
| 21 |
+
"noUnusedLocals": true,
|
| 22 |
+
"noUnusedParameters": true,
|
| 23 |
+
"erasableSyntaxOnly": true,
|
| 24 |
+
"noFallthroughCasesInSwitch": true,
|
| 25 |
+
"noUncheckedSideEffectImports": true
|
| 26 |
+
},
|
| 27 |
+
"include": ["src"]
|
| 28 |
+
}
|
tsconfig.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"files": [],
|
| 3 |
+
"references": [
|
| 4 |
+
{ "path": "./tsconfig.app.json" },
|
| 5 |
+
{ "path": "./tsconfig.node.json" }
|
| 6 |
+
],
|
| 7 |
+
"compilerOptions": {
|
| 8 |
+
"allowJs": true,
|
| 9 |
+
"checkJs": false
|
| 10 |
+
}
|
| 11 |
+
}
|
tsconfig.node.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
| 4 |
+
"target": "ES2023",
|
| 5 |
+
"lib": ["ES2023"],
|
| 6 |
+
"module": "ESNext",
|
| 7 |
+
"types": ["node"],
|
| 8 |
+
"skipLibCheck": true,
|
| 9 |
+
|
| 10 |
+
/* Bundler mode */
|
| 11 |
+
"moduleResolution": "bundler",
|
| 12 |
+
"allowImportingTsExtensions": true,
|
| 13 |
+
"verbatimModuleSyntax": true,
|
| 14 |
+
"moduleDetection": "force",
|
| 15 |
+
"noEmit": true,
|
| 16 |
+
|
| 17 |
+
/* Linting */
|
| 18 |
+
"strict": true,
|
| 19 |
+
"noUnusedLocals": true,
|
| 20 |
+
"noUnusedParameters": true,
|
| 21 |
+
"erasableSyntaxOnly": true,
|
| 22 |
+
"noFallthroughCasesInSwitch": true,
|
| 23 |
+
"noUncheckedSideEffectImports": true
|
| 24 |
+
},
|
| 25 |
+
"include": ["vite.config.ts"]
|
| 26 |
+
}
|
vite.config.js → vite.config.ts
RENAMED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
import { defineConfig } from 'vite'
|
| 2 |
import react from '@vitejs/plugin-react'
|
|
|
|
| 3 |
|
| 4 |
// https://vite.dev/config/
|
| 5 |
export default defineConfig({
|
| 6 |
-
plugins: [
|
|
|
|
|
|
|
|
|
|
| 7 |
})
|
|
|
|
| 1 |
import { defineConfig } from 'vite'
|
| 2 |
import react from '@vitejs/plugin-react'
|
| 3 |
+
import tailwindcss from '@tailwindcss/vite'
|
| 4 |
|
| 5 |
// https://vite.dev/config/
|
| 6 |
export default defineConfig({
|
| 7 |
+
plugins: [
|
| 8 |
+
react(),
|
| 9 |
+
tailwindcss(),
|
| 10 |
+
],
|
| 11 |
})
|