|
'use client'; |
|
|
|
import { useEffect, useRef, useState, useCallback } from 'react'; |
|
import Progress from './components/Progress'; |
|
import { modelConfigMap } from './components/modelConfig'; |
|
import { ModelInput } from './components/ModelInput'; |
|
|
|
/**
 * Transformers.js playground page.
 *
 * Lets the user pick a task and a model, then runs inference in a module Web
 * Worker created from `./worker.js`, so model loading and inference happen off
 * the main thread.
 *
 * Worker protocol as used by this component:
 * - posts `{ input, modelName, task }` to run inference (see `classify`);
 * - posts `{ action: 'load-model', modelName, task }` to load a model;
 * - handles incoming messages whose `status` is `initiate` | `progress` |
 *   `done` | `ready` | `complete` | `error`.
 *
 * State overview:
 * - `modelName` is the text currently in the model input (may be unsubmitted);
 * - `currentModel` is the model name sent along with inference requests;
 * - `ready` is `null` before any load, `false` while loading, and `true` once
 *   the worker reports `ready`.
 */
export default function Home() {
  // NOTE(review): `any` because the output shape depends on the task/worker;
  // tighten once the worker's message schema is pinned down.
  const [result, setResult] = useState<any | null>(null);
  const [ready, setReady] = useState<boolean | null>(null);
  const [progressItems, setProgressItems] = useState<any[]>([]);
  const [input, setInput] = useState('');
  const [task, setTask] = useState('text-classification');
  const [modelName, setModelName] = useState(() => modelConfigMap['text-classification'].defaultModel);
  const [currentModel, setCurrentModel] = useState(modelName);
  const worker = useRef<Worker | null>(null);
  const [image, setImage] = useState<File | null>(null);

  // On task change, clear everything tied to the previous task and switch both
  // the model input and the active model to the new task's default.
  useEffect(() => {
    const defaultModel = modelConfigMap[task].defaultModel;
    setResult(null);
    setInput('');
    setImage(null);
    setModelName(defaultModel);
    setCurrentModel(defaultModel);
  }, [task]);

  // Create the worker once per mount and subscribe to its messages. The
  // handler only uses state setters, so it never needs to be re-registered.
  useEffect(() => {
    if (!worker.current) {
      worker.current = new Worker(new URL('./worker.js', import.meta.url), {
        type: 'module'
      });
    }
    const activeWorker = worker.current;
    // Timers that remove finished progress bars; cleared on unmount so they
    // cannot fire against an unmounted component.
    const pendingTimers = new Set<ReturnType<typeof setTimeout>>();

    const onMessageReceived = (e: MessageEvent) => {
      switch (e.data.status) {
        case 'initiate':
          setReady(false);
          setProgressItems(prev => [...prev, { ...e.data, progress: 0 }]);
          break;

        case 'progress':
          setProgressItems(prev => prev.map(item => {
            if (item.file === e.data.file) {
              return {
                ...item,
                progress: e.data.progress,
                loaded: e.data.loaded,
                total: e.data.total,
                name: e.data.name
              };
            }
            return item;
          }));
          break;

        case 'done': {
          const doneFile = e.data.file;
          setProgressItems(prev => prev.map(item =>
            item.file === doneFile ? { ...item, done: true, progress: 100 } : item
          ));
          // Keep the completed bar visible for a second, then drop it. This is
          // scheduled here rather than inside the state updater: updaters must
          // be pure and React may invoke them more than once (StrictMode).
          const timer = setTimeout(() => {
            pendingTimers.delete(timer);
            setProgressItems(current => current.filter(item => item.file !== doneFile));
          }, 1000);
          pendingTimers.add(timer);
          break;
        }

        case 'ready':
          setReady(true);
          // Prefer a name reported by the worker; otherwise keep the model that
          // was actually requested. Falling back to the live `modelName` input
          // would adopt a name the user typed but never loaded.
          setCurrentModel(prev => e.data.file || prev);
          setProgressItems(prev => prev.filter(item => item.file !== e.data.file));
          break;

        case 'complete':
          setResult(e.data.output);
          break;

        case 'error':
          setResult({ label: 'Error', score: 0, error: e.data.error });
          break;
      }
    };

    activeWorker.addEventListener('message', onMessageReceived);

    return () => {
      activeWorker.removeEventListener('message', onMessageReceived);
      pendingTimers.forEach(t => clearTimeout(t));
      // Release the worker (and whatever model it holds). Resetting the ref
      // lets a remount (e.g. StrictMode's mount/unmount/mount) create a fresh one.
      activeWorker.terminate();
      worker.current = null;
    };
  }, []);

  /** Sends `inputValue` to the worker for inference with the active model and task. */
  const classify = useCallback((inputValue: string | Blob) => {
    if (worker.current) {
      worker.current.postMessage({
        input: inputValue,
        modelName: currentModel,
        task,
      });
    }
  }, [currentModel, task]);

  /**
   * Makes the name in the model input the active model and asks the worker to
   * load it, clearing the previous result and progress list first.
   */
  const handleLoadModel = () => {
    setReady(false);
    setResult(null);
    setProgressItems([]);
    setCurrentModel(modelName);
    if (worker.current) {
      worker.current.postMessage({ action: 'load-model', modelName, task });
    }
  };

  const InputComponent = modelConfigMap[task].inputComponent;
  const OutputComponent = modelConfigMap[task].outputComponent;

  return (
    <main className="min-h-screen w-full bg-transparent backdrop-blur-sm">
      <div className="max-w-7xl mx-auto px-4 sm:px-6 lg:px-8 py-12">
        <div className="text-center mb-12 transform transition-all duration-500 ease-in-out">
          <h1 className="text-4xl md:text-5xl font-bold text-gray-800 mb-2 hover:text-blue-600 transition-colors">
            Transformers.js Playground
          </h1>
          <p className="text-gray-500 text-lg">
            Powered by Transformers.js & Next.js (Local browser inference)
          </p>
        </div>

        <div className="mb-6 flex justify-center">
          <select
            value={task}
            onChange={e => setTask(e.target.value)}
            className="p-2 rounded border border-gray-300 text-lg font-medium shadow-sm transition"
          >
            <option value="text-classification">Text Classification</option>
            <option value="image-classification">Image Classification</option>
            <option value="automatic-speech-recognition">Automatic Speech Recognition</option>
          </select>
        </div>

        <ModelInput
          currentModel={modelName}
          onModelChange={setModelName}
          onLoadModel={handleLoadModel}
          ready={ready}
          defaultModel={modelConfigMap[task].defaultModel}
        />

        <div className="grid grid-cols-1 lg:grid-cols-2 gap-8">
          <div className="bg-white rounded-xl shadow-sm border border-gray-200 p-6 transition-all duration-300 hover:shadow-md">
            <InputComponent
              input={input}
              setInput={setInput}
              classify={classify}
              ready={ready}
              image={image}
              setImage={setImage}
            />
          </div>
          <div className="bg-white rounded-xl shadow-sm border border-gray-200 p-6 transition-all duration-300 hover:shadow-md flex flex-col">
            <h3 className="text-gray-600 mb-4 text-sm font-medium">Result</h3>
            <div className="flex-1 bg-gray-50 rounded-lg p-4 flex items-center justify-center">
              <OutputComponent
                result={result}
                ready={ready}
                task={task}
              />
            </div>
          </div>
        </div>

        {ready === false && (
          <div className="mt-12 bg-white rounded-xl shadow-sm border border-gray-200 p-6 max-w-7xl mx-auto">
            <h3 className="text-gray-600 mb-6 text-xl font-medium">Loading Model</h3>
            <div className="space-y-6">
              {progressItems.map(data => (
                // Rows are removed as files finish, so key by identity rather
                // than index to keep each bar bound to its own file.
                <div key={`${data.name ?? ''}:${data.file}`} className="transform transition-all duration-300">
                  <Progress
                    text={`${data.name || ''} - ${data.file}`}
                    percentage={data.progress}
                    loaded={data.loaded}
                    total={data.total}
                    done={data.done}
                  />
                </div>
              ))}
            </div>
          </div>
        )}
      </div>
    </main>
  );
}
|
|