Spaces:
Running
Running
| import { useEffect, useState, useRef } from "react"; | |
| import * as tf from "@tensorflow/tfjs"; | |
| import Plot from "react-plotly.js"; | |
| import { MnistData } from "./mnist.js"; | |
| import { Cnn, train } from "./train.ts"; | |
| import type { TrainController, RunInfo, OptimizerParams } from "./train.ts"; | |
| import { Button, InputField, Tabs, Dropdown, Card } from "@elvis/ui"; | |
| import InfoViewer from "./InfoViewer.tsx"; | |
| const DEFAULT_ARCHITECTURE = `[conv2d filters=8 kernel=11 | |
| stride=1 padding=1 activation=relu] | |
| [maxpool size=2 stride=2] | |
| [flatten] | |
| [dense units=10 activation=softmax]`; | |
| const isFirefox = navigator.userAgent.toLowerCase().includes("firefox"); | |
| await tf.setBackend(isFirefox ? "cpu" : "webgl"); | |
| await tf.ready(); | |
| export default function NetworkVisualizer() { | |
| const [dataset, setDataset] = useState<MnistData | null>(null); | |
| useEffect(() => { | |
| const loadData = async () => { | |
| const data = new MnistData(); | |
| await data.load(); | |
| setDataset(data); | |
| console.log("dataset loaded"); | |
| } | |
| loadData(); | |
| }, []); | |
| // architecture states | |
| const [architecture, setArchitecture] = useState(DEFAULT_ARCHITECTURE); | |
| const [optimizerType, setOptimizerType] = useState<string>('adam'); | |
| const [optimizerParams, setOptimizerParams] = useState<OptimizerParams>({ | |
| learningRate: '0.001', | |
| beta1: '0.9', | |
| beta2: '0.999', | |
| epsilon: '1e-8', | |
| batchSize: '32', | |
| epochs: '5', | |
| }); | |
| const modelRef = useRef<Cnn | null>(null); | |
| const optimizerRef = useRef<tf.Optimizer | null>(null); | |
| function handleArchitectureChange(newArchitecture: string) { | |
| if (isTraining) { | |
| alert('Cannot change architecture while training is in progress.'); | |
| } else { | |
| setArchitecture(newArchitecture); | |
| } | |
| } | |
| function handleOptimizerChange(newOptimizerType: string, newOptimizerParams: OptimizerParams) { | |
| if (isTraining) { | |
| alert('Cannot change optimizer settings while training is in progress.'); | |
| } else { | |
| setOptimizerType(newOptimizerType); | |
| setOptimizerParams(newOptimizerParams); | |
| } | |
| } | |
| function handleSampleIndexChange() { | |
| trainController.current.sampleIndex += 1; | |
| updateTick(); | |
| } | |
| function resetModel() { | |
| if (!dataset) return; | |
| if (modelRef.current) { | |
| modelRef.current.dispose(); | |
| } | |
| const cnn = new Cnn(architecture, dataset.numInputChannels); | |
| modelRef.current = cnn; | |
| } | |
| function resetOptimizer() { | |
| if (optimizerType === 'adam') { | |
| const learningRate = parseFloat(optimizerParams.learningRate); | |
| const beta1 = parseFloat(optimizerParams.beta1 || "0.9"); | |
| const beta2 = parseFloat(optimizerParams.beta2 || "0.999"); | |
| const epsilon = parseFloat(optimizerParams.epsilon || "1e-8"); | |
| if (Number.isNaN(learningRate) || learningRate <= 0) { | |
| alert('Invalid learning rate for Adam optimizer.'); | |
| return; | |
| } | |
| if (Number.isNaN(beta1) || beta1 < 0) { | |
| alert('Invalid beta1 for Adam optimizer.'); | |
| return; | |
| } | |
| if (Number.isNaN(beta2) || beta2 < 0) { | |
| alert('Invalid beta2 for Adam optimizer.'); | |
| return; | |
| } | |
| if (Number.isNaN(epsilon) || epsilon <= 0) { | |
| alert('Invalid epsilon for Adam optimizer.'); | |
| return; | |
| } | |
| const opt = tf.train.adam(learningRate, beta1, beta2, epsilon); | |
| if (optimizerRef.current) { | |
| optimizerRef.current.dispose(); | |
| } | |
| optimizerRef.current = opt; | |
| } else if (optimizerType === 'sgd') { | |
| const learningRate = parseFloat(optimizerParams.learningRate); | |
| if (Number.isNaN(learningRate) || learningRate <= 0) { | |
| alert('Invalid learning rate for SGD optimizer.'); | |
| return; | |
| } | |
| const opt = tf.train.sgd(learningRate); | |
| if (optimizerRef.current) { | |
| optimizerRef.current.dispose(); | |
| } | |
| optimizerRef.current = opt; | |
| } else { | |
| alert(`Unsupported optimizer type: ${optimizerType}`); | |
| } | |
| } | |
| // reset & init model and optimizer | |
| useEffect(() => { | |
| resetModel(); | |
| resetOptimizer(); | |
| }, [architecture, optimizerType, optimizerParams, dataset]); | |
| // training states | |
| const [isTraining, setIsTraining] = useState<boolean>(false); | |
| const lossesRef = useRef<Array<number>>([]); | |
| const trainController = useRef<TrainController>({ | |
| isPaused: false, | |
| stopRequested: false, | |
| sampleIndex: 0, | |
| }); | |
| const infoRef = useRef<Array<RunInfo>>([]); | |
| // render timing | |
| const [, setTick] = useState<number>(0); | |
| function updateTick() { | |
| setTick((tick) => tick + 1); | |
| } | |
| async function startTraining() { | |
| if (!modelRef || !dataset || !optimizerRef || isTraining) { | |
| return; | |
| } | |
| setIsTraining(true); | |
| trainController.current.isPaused = false; | |
| trainController.current.stopRequested = false; | |
| const batchSize = parseFloat(optimizerParams.batchSize); | |
| if (Number.isNaN(batchSize) || batchSize <= 0) { | |
| alert('Invalid batch size.'); | |
| setIsTraining(false); | |
| return; | |
| } | |
| const epochs = parseFloat(optimizerParams.epochs); | |
| if (Number.isNaN(epochs) || epochs <= 0) { | |
| alert('Invalid number of epochs.'); | |
| setIsTraining(false); | |
| return; | |
| } | |
| let lastTickUpdate = 0; | |
| if (!modelRef.current) return; | |
| if (!optimizerRef.current) return; | |
| try { | |
| await train( | |
| dataset, | |
| modelRef.current, | |
| optimizerRef.current, | |
| batchSize, | |
| epochs, | |
| trainController.current, | |
| (_epoch, _batch, loss, info) => { | |
| // lossesRef.current.push({ epoch, batch, loss }); | |
| lossesRef.current.push(loss); | |
| console.log(loss); | |
| infoRef.current = info; | |
| // update tick every 50ms | |
| const now = performance.now(); | |
| if (now - lastTickUpdate > 50) { | |
| lastTickUpdate = now; | |
| updateTick(); | |
| } | |
| }, | |
| ); | |
| } finally { | |
| setIsTraining(false); | |
| trainController.current.isPaused = false; | |
| trainController.current.stopRequested = false; | |
| alert('Training finished.'); | |
| } | |
| } | |
| function handleStartTraining() { | |
| console.log('Starting training...'); | |
| // trainController updated in startTraining | |
| startTraining(); | |
| } | |
| function handlePauseTraining() { | |
| console.log('Pausing training...'); | |
| trainController.current.isPaused = true; | |
| } | |
| function handleContinueTraining() { | |
| console.log('Continuing training...'); | |
| trainController.current.isPaused = false; | |
| } | |
| function handleStopTraining() { | |
| console.log('Stopping training...'); | |
| trainController.current.stopRequested = true; | |
| trainController.current.isPaused = false; | |
| } | |
| async function waitUntilNotTraining() { | |
| return new Promise<void>((resolve) => { | |
| function check() { | |
| if (!isTraining) { | |
| resolve(); | |
| } else { | |
| requestAnimationFrame(check); | |
| } | |
| } | |
| check(); | |
| }); | |
| } | |
| async function handleResetTraining() { | |
| console.log('Resetting training...'); | |
| handleStopTraining(); | |
| await waitUntilNotTraining(); | |
| console.log('Training stopped. Resetting model.'); | |
| lossesRef.current = []; | |
| infoRef.current = []; | |
| resetModel(); | |
| resetOptimizer(); | |
| updateTick(); | |
| } | |
| return ( | |
| <div className="grid grid-cols-[2fr_1fr] min-h-0 h-full gap-12"> | |
| <TrainingViewer | |
| isTraining={isTraining} | |
| lossesRef={lossesRef} | |
| infoRef={infoRef} | |
| handleSampleIndexChange={handleSampleIndexChange} | |
| /> | |
| <Sidebar | |
| architecture={architecture} | |
| onArchitectureChange={handleArchitectureChange} | |
| optimizerType={optimizerType} | |
| optimizerParams={optimizerParams} | |
| onOptimizerChange={handleOptimizerChange} | |
| onStartTraining={handleStartTraining} | |
| onPauseTraining={handlePauseTraining} | |
| onContinueTraining={handleContinueTraining} | |
| onStopTraining={handleStopTraining} | |
| onResetTraining={handleResetTraining} | |
| /> | |
| </div> | |
| ); | |
| } | |
/** Props for {@link TrainingViewer}. */
interface TrainingViewerProps {
  /** Whether a training run is currently in progress. */
  isTraining: boolean;
  /** Mutable loss history (one entry per batch); mutated in place by the training callback. */
  lossesRef: React.RefObject<Array<number>>;
  /** Latest per-run info from the training loop, rendered by InfoViewer. */
  infoRef: React.RefObject<Array<RunInfo>>;
  /** Advances the controller's sample index and forces a re-render. */
  handleSampleIndexChange: () => void;
}
| function TrainingViewer({ | |
| isTraining, | |
| lossesRef, | |
| infoRef, | |
| handleSampleIndexChange, | |
| }: TrainingViewerProps) { | |
| return ( | |
| <div className="flex flex-col h-full overflow-auto gap-4 w-full min-h-0"> | |
| <Card className="flex flex-col gap-4"> | |
| <p>Training { isTraining ? "in progress" : "not in progress" }</p> | |
| <Plot | |
| data={[ | |
| { | |
| x: lossesRef.current.map((_, i) => i), | |
| y: lossesRef.current, | |
| mode: 'lines', | |
| type: 'scatter', | |
| }, | |
| ]} | |
| layout={{ | |
| xaxis: { title: { text: 'Training steps' } }, | |
| yaxis: { title: { text: 'Train loss' } }, | |
| margin: { t: 40, r: 40, b: 40, l: 40 }, | |
| }} | |
| className="w-full h-[320px]" | |
| config={{ responsive: true }} | |
| /> | |
| </Card> | |
| <InfoViewer info={infoRef.current} onSampleIndexChange={handleSampleIndexChange} /> | |
| </div> | |
| ) | |
| } | |
/** Props for {@link Sidebar}. */
interface SidebarProps {
  /** Currently applied architecture text (the draft being edited lives in Sidebar state). */
  architecture: string;
  /** Applies a new architecture; rejected by the parent while training. */
  onArchitectureChange: (newArchitecture: string) => void;
  /** Current optimizer kind ('sgd' | 'adam'). */
  optimizerType: string;
  /** Raw string values backing the hyper-parameter text fields. */
  optimizerParams: OptimizerParams;
  /** Applies a new optimizer type and/or params; rejected by the parent while training. */
  onOptimizerChange: (newOptimizerType: string, newOptimizerParams: OptimizerParams) => void;
  onStartTraining: () => void;
  onPauseTraining: () => void;
  onContinueTraining: () => void;
  onStopTraining: () => void;
  onResetTraining: () => void;
}
| function Sidebar({ | |
| architecture, | |
| onArchitectureChange, | |
| optimizerType, | |
| optimizerParams, | |
| onOptimizerChange, | |
| onStartTraining, | |
| onPauseTraining, | |
| onContinueTraining, | |
| onStopTraining, | |
| onResetTraining, | |
| }: SidebarProps) { | |
| const tabs = ["Architecture", "Train"]; | |
| const [activeTab, setActiveTab] = useState<string>(tabs[0]); | |
| const [architectureDraft, setArchitectureDraft] = useState<string>(architecture); | |
| return ( | |
| <Card className="flex flex-col h-full p-4 gap-2 overflow-auto"> | |
| <Tabs tabs={tabs} activeTab={activeTab} onChange={setActiveTab} /> | |
| <div className="flex flex-col p-4 gap-4 h-full overflow-auto"> | |
| { isFirefox && ( | |
| <p className="text-red-500"> | |
| Warning: This demo may be quite slow on Firefox. | |
| </p> | |
| )} | |
| { activeTab === "Architecture" && ( | |
| <> | |
| <InputField | |
| label="Architecture" | |
| value={architectureDraft} | |
| onChange={setArchitectureDraft} | |
| rows={15} | |
| /> | |
| <Button | |
| label="Apply architecture" | |
| onClick={() => onArchitectureChange(architectureDraft)} | |
| /> | |
| </> | |
| )} | |
| { activeTab === "Train" && ( | |
| <> | |
| <Dropdown | |
| label="Optimizer" | |
| options={["sgd", "adam"]} | |
| activeOption={optimizerType} | |
| onChange={(newOptimizerType) => onOptimizerChange(newOptimizerType, optimizerParams)} | |
| /> | |
| { optimizerType === 'sgd' && ( | |
| <InputField | |
| label="Learning Rate" | |
| value={optimizerParams.learningRate} | |
| onChange={(newLearningRate) => onOptimizerChange(optimizerType, {...optimizerParams, learningRate: newLearningRate})} | |
| /> | |
| )} | |
| { optimizerType === 'adam' && ( | |
| <> | |
| <InputField | |
| label="Learning Rate" | |
| value={optimizerParams.learningRate} | |
| onChange={(newLearningRate) => onOptimizerChange(optimizerType, {...optimizerParams, learningRate: newLearningRate})} | |
| /> | |
| <InputField | |
| label="Beta 1" | |
| value={optimizerParams.beta1} | |
| onChange={(newBeta1) => onOptimizerChange(optimizerType, {...optimizerParams, beta1: newBeta1})} | |
| /> | |
| <InputField | |
| label="Beta 2" | |
| value={optimizerParams.beta2} | |
| onChange={(newBeta2) => onOptimizerChange(optimizerType, {...optimizerParams, beta2: newBeta2})} | |
| /> | |
| <InputField | |
| label="Epsilon" | |
| value={optimizerParams.epsilon} | |
| onChange={(newEpsilon) => onOptimizerChange(optimizerType, {...optimizerParams, epsilon: newEpsilon})} | |
| /> | |
| </> | |
| )} | |
| <InputField | |
| label="Batch Size" | |
| value={optimizerParams.batchSize} | |
| onChange={(newBatchSize) => onOptimizerChange(optimizerType, {...optimizerParams, batchSize: newBatchSize})} | |
| /> | |
| <InputField | |
| label="Epochs" | |
| value={optimizerParams.epochs} | |
| onChange={(newEpochs) => onOptimizerChange(optimizerType, {...optimizerParams, epochs: newEpochs})} | |
| /> | |
| <Button | |
| label="Start training" | |
| onClick={onStartTraining} | |
| /> | |
| <Button | |
| label="Pause training" | |
| onClick={onPauseTraining} | |
| /> | |
| <Button | |
| label="Continue training" | |
| onClick={onContinueTraining} | |
| /> | |
| <Button | |
| label="Stop training" | |
| onClick={onStopTraining} | |
| /> | |
| <Button | |
| label="Reset training" | |
| onClick={onResetTraining} | |
| /> | |
| </> | |
| )} | |
| </div> | |
| </Card> | |
| ) | |
| } | |