Vokturz commited on
Commit
22f8eb7
·
1 Parent(s): 0c10cf2

add feature-extraction

Browse files
public/workers/feature-extraction.js ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* eslint-disable no-restricted-globals */
2
+ import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.6.3'
3
+
4
+ class MyFeatureExtractionPipeline {
5
+ static task = 'feature-extraction'
6
+ static instance = null
7
+
8
+ static async getInstance(model, dtype = 'fp32', progress_callback = null) {
9
+ try {
10
+ // Try WebGPU first
11
+ this.instance = await pipeline(this.task, model, {
12
+ dtype,
13
+ device: 'webgpu',
14
+ progress_callback
15
+ })
16
+ return this.instance
17
+ } catch (webgpuError) {
18
+ // Fallback to WASM if WebGPU fails
19
+ if (progress_callback) {
20
+ progress_callback({
21
+ status: 'fallback',
22
+ message: 'WebGPU failed, falling back to WASM'
23
+ })
24
+ }
25
+ try {
26
+ this.instance = await pipeline(this.task, model, {
27
+ dtype,
28
+ device: 'wasm',
29
+ progress_callback
30
+ })
31
+ return this.instance
32
+ } catch (wasmError) {
33
+ throw new Error(
34
+ `Both WebGPU and WASM failed. WebGPU error: ${webgpuError.message}. WASM error: ${wasmError.message}`
35
+ )
36
+ }
37
+ }
38
+ }
39
+ }
40
+
41
+ // Listen for messages from the main thread
42
+ self.addEventListener('message', async (event) => {
43
+ try {
44
+ const { type, model, dtype, texts, config } = event.data
45
+
46
+ if (!model) {
47
+ self.postMessage({
48
+ status: 'error',
49
+ output: 'No model provided'
50
+ })
51
+ return
52
+ }
53
+
54
+ // Get the pipeline instance
55
+ const extractor = await MyFeatureExtractionPipeline.getInstance(
56
+ model,
57
+ dtype,
58
+ (x) => {
59
+ self.postMessage({ status: 'loading', output: x })
60
+ }
61
+ )
62
+
63
+ if (type === 'load') {
64
+ self.postMessage({
65
+ status: 'ready',
66
+ output: `Feature extraction model ${model}, dtype ${dtype} loaded`
67
+ })
68
+ return
69
+ }
70
+
71
+ if (type === 'extract') {
72
+ if (!texts || !Array.isArray(texts) || texts.length === 0) {
73
+ self.postMessage({
74
+ status: 'error',
75
+ output: 'No texts provided for feature extraction'
76
+ })
77
+ return
78
+ }
79
+
80
+ const embeddings = []
81
+
82
+ for (let i = 0; i < texts.length; i++) {
83
+ const text = texts[i]
84
+ try {
85
+ const output = await extractor(text, config)
86
+
87
+ // Convert tensor to array and get the embedding
88
+ let embedding
89
+ if (output && typeof output.tolist === 'function') {
90
+ embedding = output.tolist()
91
+ } else if (Array.isArray(output)) {
92
+ embedding = output
93
+ } else if (output && output.data) {
94
+ embedding = Array.from(output.data)
95
+ } else {
96
+ throw new Error('Unexpected output format from feature extraction')
97
+ }
98
+
99
+ // If the embedding is 2D (batch dimension), take the first element
100
+ if (Array.isArray(embedding[0])) {
101
+ embedding = embedding[0]
102
+ }
103
+
104
+ embeddings.push({
105
+ text: text,
106
+ embedding: embedding,
107
+ index: i
108
+ })
109
+
110
+ // Send progress update
111
+ self.postMessage({
112
+ status: 'progress',
113
+ output: {
114
+ completed: i + 1,
115
+ total: texts.length,
116
+ currentText: text,
117
+ embedding: embedding
118
+ }
119
+ })
120
+ } catch (error) {
121
+ embeddings.push({
122
+ text: text,
123
+ embedding: null,
124
+ error: error.message,
125
+ index: i
126
+ })
127
+
128
+ self.postMessage({
129
+ status: 'progress',
130
+ output: {
131
+ completed: i + 1,
132
+ total: texts.length,
133
+ currentText: text,
134
+ error: error.message
135
+ }
136
+ })
137
+ }
138
+ }
139
+
140
+ self.postMessage({
141
+ status: 'output',
142
+ output: {
143
+ embeddings: embeddings,
144
+ completed: true
145
+ }
146
+ })
147
+
148
+ self.postMessage({ status: 'ready' })
149
+ }
150
+ } catch (error) {
151
+ self.postMessage({
152
+ status: 'error',
153
+ output: error.message || 'An error occurred during feature extraction'
154
+ })
155
+ }
156
+ })
src/App.tsx CHANGED
@@ -6,6 +6,7 @@ import Header from './Header'
6
  import { useModel } from './contexts/ModelContext'
7
  import { getModelsByPipeline } from './lib/huggingface'
8
  import TextGeneration from './components/TextGeneration'
 
9
  import Sidebar from './components/Sidebar'
10
  import ModelReadme from './components/ModelReadme'
11
  import { PipelineLayout } from './components/PipelineLayout'
@@ -70,6 +71,7 @@ function App() {
70
  )}
71
  {pipeline === 'text-classification' && <TextClassification />}
72
  {pipeline === 'text-generation' && <TextGeneration />}
 
73
  </div>
74
  </div>
75
  </main>
 
6
  import { useModel } from './contexts/ModelContext'
7
  import { getModelsByPipeline } from './lib/huggingface'
8
  import TextGeneration from './components/TextGeneration'
9
+ import FeatureExtraction from './components/FeatureExtraction'
10
  import Sidebar from './components/Sidebar'
11
  import ModelReadme from './components/ModelReadme'
12
  import { PipelineLayout } from './components/PipelineLayout'
 
71
  )}
72
  {pipeline === 'text-classification' && <TextClassification />}
73
  {pipeline === 'text-generation' && <TextGeneration />}
74
+ {pipeline === 'feature-extraction' && <FeatureExtraction />}
75
  </div>
76
  </div>
77
  </main>
src/components/FeatureExtraction.tsx ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useRef, useEffect, useCallback } from 'react'
2
+ import { Plus, Trash2, Loader2, X, Eye, EyeOff } from 'lucide-react'
3
+ import {
4
+ EmbeddingExample,
5
+ FeatureExtractionWorkerInput,
6
+ WorkerMessage
7
+ } from '../types'
8
+ import { useModel } from '../contexts/ModelContext'
9
+ import { useFeatureExtraction } from '../contexts/FeatureExtractionContext'
10
+
11
+ interface Point2D {
12
+ x: number
13
+ y: number
14
+ id: string
15
+ text: string
16
+ similarity?: number
17
+ }
18
+
19
+ // Sample data for quick testing
20
+ const SAMPLE_TEXTS = [
21
+ 'The cat sat on the mat',
22
+ 'A feline rested on the carpet',
23
+ 'I love programming in JavaScript',
24
+ 'JavaScript development is my passion',
25
+ 'The weather is beautiful today',
26
+ "It's a sunny and warm day outside",
27
+ 'Machine learning is transforming technology',
28
+ 'AI and deep learning are revolutionizing computing',
29
+ 'I enjoy reading books in the evening',
30
+ 'Pizza is one of my favorite foods'
31
+ ]
32
+
33
+ function FeatureExtraction() {
34
+ const {
35
+ examples,
36
+ selectedExample,
37
+ setSelectedExample,
38
+ similarities,
39
+ addExample,
40
+ removeExample,
41
+ updateExample,
42
+ calculateSimilarities,
43
+ clearExamples,
44
+ config
45
+ } = useFeatureExtraction()
46
+
47
+ const [newExampleText, setNewExampleText] = useState<string>('')
48
+ const [isExtracting, setIsExtracting] = useState<boolean>(false)
49
+ const [showVisualization, setShowVisualization] = useState<boolean>(true)
50
+ const [progress, setProgress] = useState<{
51
+ completed: number
52
+ total: number
53
+ } | null>(null)
54
+
55
+ const {
56
+ activeWorker,
57
+ status,
58
+ modelInfo,
59
+ hasBeenLoaded,
60
+ selectedQuantization
61
+ } = useModel()
62
+
63
+ const chartRef = useRef<SVGSVGElement>(null)
64
+
65
+ // PCA reduction to 2D for visualization
66
+ const reduceTo2D = useCallback(
67
+ (embeddings: number[][]): Point2D[] => {
68
+ if (embeddings.length === 0) return []
69
+
70
+ // For simplicity, just use first 2 dimensions if available, or random projection
71
+ const points: Point2D[] = examples
72
+ .filter((ex) => ex.embedding)
73
+ .map((example, i) => {
74
+ const emb = example.embedding!
75
+ let x, y
76
+
77
+ if (emb.length >= 2) {
78
+ x = emb[0]
79
+ y = emb[1]
80
+ } else {
81
+ // Simple hash-based positioning for 1D embeddings
82
+ const hash = example.text.split('').reduce((a, b) => {
83
+ a = (a << 5) - a + b.charCodeAt(0)
84
+ return a & a
85
+ }, 0)
86
+ x = Math.sin(hash) * 100
87
+ y = Math.cos(hash) * 100
88
+ }
89
+
90
+ return {
91
+ x,
92
+ y,
93
+ id: example.id,
94
+ text: example.text,
95
+ similarity: similarities.find((s) => s.exampleId === example.id)
96
+ ?.similarity
97
+ }
98
+ })
99
+
100
+ // Normalize points to fit in chart
101
+ if (points.length > 0) {
102
+ const minX = Math.min(...points.map((p) => p.x))
103
+ const maxX = Math.max(...points.map((p) => p.x))
104
+ const minY = Math.min(...points.map((p) => p.y))
105
+ const maxY = Math.max(...points.map((p) => p.y))
106
+
107
+ const rangeX = maxX - minX || 1
108
+ const rangeY = maxY - minY || 1
109
+
110
+ return points.map((p) => ({
111
+ ...p,
112
+ x: ((p.x - minX) / rangeX) * 300 + 50,
113
+ y: ((p.y - minY) / rangeY) * 200 + 50
114
+ }))
115
+ }
116
+
117
+ return points
118
+ },
119
+ [examples, similarities]
120
+ )
121
+
122
+ const extractEmbeddings = useCallback(
123
+ async (textsToExtract: string[]) => {
124
+ if (!modelInfo || !activeWorker || textsToExtract.length === 0) return
125
+
126
+ setIsExtracting(true)
127
+ setProgress({ completed: 0, total: textsToExtract.length })
128
+
129
+ const message: FeatureExtractionWorkerInput = {
130
+ type: 'extract',
131
+ texts: textsToExtract,
132
+ model: modelInfo.id,
133
+ dtype: selectedQuantization ?? 'fp32',
134
+ config
135
+ }
136
+
137
+ activeWorker.postMessage(message)
138
+ },
139
+ [modelInfo, activeWorker, selectedQuantization, config]
140
+ )
141
+
142
+ const handleAddExample = useCallback(() => {
143
+ if (!newExampleText.trim()) return
144
+
145
+ addExample(newExampleText)
146
+ setNewExampleText('')
147
+ }, [newExampleText, addExample])
148
+
149
+ const handleExtractAll = useCallback(() => {
150
+ const textsToExtract = examples
151
+ .filter((ex) => !ex.embedding && !ex.isLoading)
152
+ .map((ex) => ex.text)
153
+
154
+ if (textsToExtract.length > 0) {
155
+ extractEmbeddings(textsToExtract)
156
+ }
157
+ }, [examples, extractEmbeddings])
158
+
159
+ const handleSelectExample = useCallback(
160
+ (example: EmbeddingExample) => {
161
+ setSelectedExample(example)
162
+ if (example.embedding) {
163
+ calculateSimilarities(example)
164
+ }
165
+ },
166
+ [setSelectedExample, calculateSimilarities]
167
+ )
168
+
169
+ const handleLoadSampleData = useCallback(() => {
170
+ SAMPLE_TEXTS.forEach((text) => addExample(text))
171
+ }, [addExample])
172
+
173
+ useEffect(() => {
174
+ if (!activeWorker) return
175
+
176
+ const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
177
+ const { status, output } = e.data
178
+
179
+ if (status === 'progress' && output) {
180
+ setProgress({ completed: output.completed, total: output.total })
181
+
182
+ if (output.embedding && output.currentText) {
183
+ const example = examples.find((ex) => ex.text === output.currentText)
184
+ if (example) {
185
+ updateExample(example.id, {
186
+ embedding: output.embedding,
187
+ isLoading: false
188
+ })
189
+ }
190
+ }
191
+ } else if (status === 'output' && output?.embeddings) {
192
+ output.embeddings.forEach((result: any) => {
193
+ const example = examples.find((ex) => ex.text === result.text)
194
+ if (example) {
195
+ updateExample(example.id, {
196
+ embedding: result.embedding,
197
+ isLoading: false
198
+ })
199
+ }
200
+ })
201
+ console.log({ examples })
202
+ setIsExtracting(false)
203
+ setProgress(null)
204
+ } else if (status === 'error') {
205
+ setIsExtracting(false)
206
+ setProgress(null)
207
+ }
208
+ }
209
+
210
+ activeWorker.addEventListener('message', onMessageReceived)
211
+ return () => activeWorker.removeEventListener('message', onMessageReceived)
212
+ }, [activeWorker, examples, updateExample])
213
+
214
+ const handleKeyPress = (e: React.KeyboardEvent) => {
215
+ if (e.key === 'Enter' && !e.shiftKey) {
216
+ e.preventDefault()
217
+ handleAddExample()
218
+ }
219
+ }
220
+
221
+ const points2D = reduceTo2D(
222
+ examples.filter((ex) => ex.embedding).map((ex) => ex.embedding!)
223
+ )
224
+ const busy = status !== 'ready' || isExtracting
225
+
226
+ return (
227
+ <div className="flex flex-col h-[70vh] max-h-[100vh] w-full p-4">
228
+ <div className="flex items-center justify-between mb-4">
229
+ <h1 className="text-2xl font-bold">Feature Extraction (Embeddings)</h1>
230
+ <div className="flex gap-2">
231
+ <button
232
+ onClick={handleLoadSampleData}
233
+ disabled={!hasBeenLoaded || isExtracting}
234
+ className="px-3 py-2 bg-purple-100 hover:bg-purple-200 disabled:bg-gray-100 disabled:cursor-not-allowed rounded-lg transition-colors text-sm"
235
+ title="Load Sample Data"
236
+ >
237
+ Load Samples
238
+ </button>
239
+ <button
240
+ onClick={() => setShowVisualization(!showVisualization)}
241
+ className="p-2 bg-blue-100 hover:bg-blue-200 rounded-lg transition-colors"
242
+ title={
243
+ showVisualization ? 'Hide Visualization' : 'Show Visualization'
244
+ }
245
+ >
246
+ {showVisualization ? (
247
+ <EyeOff className="w-4 h-4" />
248
+ ) : (
249
+ <Eye className="w-4 h-4" />
250
+ )}
251
+ </button>
252
+ <button
253
+ onClick={clearExamples}
254
+ className="p-2 bg-red-100 hover:bg-red-200 rounded-lg transition-colors"
255
+ title="Clear All Examples"
256
+ >
257
+ <Trash2 className="w-4 h-4" />
258
+ </button>
259
+ </div>
260
+ </div>
261
+
262
+ <div className="flex flex-col lg:flex-row gap-4 flex-1">
263
+ {/* Left Panel - Examples */}
264
+ <div className="lg:w-1/2 flex flex-col">
265
+ {/* Add Example */}
266
+ <div className="mb-4">
267
+ <label className="block text-sm font-medium text-gray-700 mb-2">
268
+ Add Text Examples:
269
+ </label>
270
+ <div className="flex gap-2">
271
+ <textarea
272
+ value={newExampleText}
273
+ onChange={(e) => setNewExampleText(e.target.value)}
274
+ onKeyPress={handleKeyPress}
275
+ placeholder="Enter text to get embeddings... (Press Enter to add)"
276
+ className="flex-1 p-3 border border-gray-300 rounded-lg resize-none focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 disabled:bg-gray-100 disabled:cursor-not-allowed"
277
+ rows={2}
278
+ disabled={!hasBeenLoaded || isExtracting}
279
+ />
280
+ <button
281
+ onClick={handleAddExample}
282
+ disabled={!newExampleText.trim() || !hasBeenLoaded}
283
+ className="px-4 py-2 bg-blue-500 hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors"
284
+ >
285
+ <Plus className="w-4 h-4" />
286
+ </button>
287
+ </div>
288
+ </div>
289
+
290
+ {/* Extract Button */}
291
+ {examples.some((ex) => !ex.embedding) && (
292
+ <div className="mb-4">
293
+ <button
294
+ onClick={handleExtractAll}
295
+ disabled={busy || !hasBeenLoaded}
296
+ className="px-6 py-2 bg-green-500 hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed text-white rounded-lg transition-colors flex items-center gap-2"
297
+ >
298
+ {isExtracting ? (
299
+ <>
300
+ <Loader2 className="w-4 h-4 animate-spin" />
301
+ Extracting...{' '}
302
+ {progress && `(${progress.completed}/${progress.total})`}
303
+ </>
304
+ ) : (
305
+ 'Extract Embeddings'
306
+ )}
307
+ </button>
308
+ </div>
309
+ )}
310
+
311
+ {/* Examples List */}
312
+ <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg bg-white">
313
+ <div className="p-4">
314
+ <h3 className="text-sm font-medium text-gray-700 mb-3">
315
+ Examples ({examples.length})
316
+ </h3>
317
+ {examples.length === 0 ? (
318
+ <div className="text-gray-500 italic text-center py-8">
319
+ No examples added yet. Add some text above to get started.
320
+ </div>
321
+ ) : (
322
+ <div className="space-y-2">
323
+ {examples.map((example) => (
324
+ <div
325
+ key={example.id}
326
+ className={`p-3 border rounded-lg cursor-pointer transition-colors ${
327
+ selectedExample?.id === example.id
328
+ ? 'border-blue-500 bg-blue-50'
329
+ : 'border-gray-200 hover:border-gray-300'
330
+ }`}
331
+ onClick={() => handleSelectExample(example)}
332
+ >
333
+ <div className="flex justify-between items-start">
334
+ <div className="flex-1 min-w-0">
335
+ <div className="text-sm text-gray-800 break-words">
336
+ {example.text}
337
+ </div>
338
+ <div className="flex items-center gap-2 mt-1">
339
+ {example.isLoading ? (
340
+ <div className="flex items-center gap-1 text-xs text-blue-600">
341
+ <Loader2 className="w-3 h-3 animate-spin" />
342
+ Extracting...
343
+ </div>
344
+ ) : example.embedding ? (
345
+ <div className="text-xs text-green-600">
346
+ ✓ Embedding ready ({example.embedding.length}D)
347
+ </div>
348
+ ) : (
349
+ <div className="text-xs text-gray-500">
350
+ No embedding
351
+ </div>
352
+ )}
353
+ {selectedExample?.id === example.id &&
354
+ similarities.length > 0 && (
355
+ <div className="text-xs text-blue-600">
356
+ Selected
357
+ </div>
358
+ )}
359
+ </div>
360
+ </div>
361
+ <button
362
+ onClick={(e) => {
363
+ e.stopPropagation()
364
+ removeExample(example.id)
365
+ }}
366
+ className="ml-2 p-1 text-red-500 hover:text-red-700 transition-colors"
367
+ >
368
+ <X className="w-3 h-3" />
369
+ </button>
370
+ </div>
371
+ </div>
372
+ ))}
373
+ </div>
374
+ )}
375
+ </div>
376
+ </div>
377
+ </div>
378
+
379
+ {/* Right Panel - Visualization and Similarities */}
380
+ <div className="lg:w-1/2 flex flex-col">
381
+ {showVisualization && (
382
+ <div className="mb-4">
383
+ <h3 className="text-sm font-medium text-gray-700 mb-2">
384
+ 2D Visualization
385
+ </h3>
386
+ <div className="border border-gray-300 rounded-lg bg-white p-4">
387
+ <svg
388
+ ref={chartRef}
389
+ width="100%"
390
+ height="300"
391
+ viewBox="0 0 400 300"
392
+ className="border border-gray-100"
393
+ >
394
+ {points2D.map((point) => {
395
+ const isSelected = selectedExample?.id === point.id
396
+ const similarity = point.similarity
397
+
398
+ // Color based on similarity to selected example
399
+ let fillColor = '#6b7280' // default gray
400
+ if (isSelected) {
401
+ fillColor = '#3b82f6' // blue for selected
402
+ } else if (similarity !== undefined) {
403
+ if (similarity > 0.8)
404
+ fillColor = '#10b981' // green for high similarity
405
+ else if (similarity > 0.5)
406
+ fillColor = '#f59e0b' // yellow for medium similarity
407
+ else fillColor = '#ef4444' // red for low similarity
408
+ }
409
+
410
+ return (
411
+ <g key={point.id}>
412
+ <circle
413
+ cx={point.x}
414
+ cy={point.y}
415
+ r={isSelected ? 8 : 5}
416
+ fill={fillColor}
417
+ stroke="white"
418
+ strokeWidth="2"
419
+ className="cursor-pointer hover:stroke-4 transition-all duration-200"
420
+ onClick={() => {
421
+ const example = examples.find(
422
+ (ex) => ex.id === point.id
423
+ )
424
+ if (example) handleSelectExample(example)
425
+ }}
426
+ style={{
427
+ filter: isSelected
428
+ ? 'drop-shadow(0 0 6px rgba(59, 130, 246, 0.6))'
429
+ : 'none'
430
+ }}
431
+ />
432
+ <text
433
+ x={point.x + 10}
434
+ y={point.y + 4}
435
+ fontSize="9"
436
+ fill="#374151"
437
+ className="pointer-events-none font-medium"
438
+ style={{
439
+ textShadow: '1px 1px 2px rgba(255,255,255,0.8)'
440
+ }}
441
+ >
442
+ {point.text.substring(0, 15)}...
443
+ </text>
444
+ {similarity !== undefined && (
445
+ <text
446
+ x={point.x}
447
+ y={point.y - 10}
448
+ fontSize="8"
449
+ fill={fillColor}
450
+ className="pointer-events-none font-bold text-center"
451
+ textAnchor="middle"
452
+ >
453
+ {(similarity * 100).toFixed(0)}%
454
+ </text>
455
+ )}
456
+ </g>
457
+ )
458
+ })}
459
+ </svg>
460
+ {points2D.length === 0 && (
461
+ <div className="text-center text-gray-500 py-8">
462
+ Extract embeddings to see visualization
463
+ </div>
464
+ )}
465
+ {points2D.length > 0 && (
466
+ <div className="mt-3 p-3 bg-gray-50 rounded-lg">
467
+ <h4 className="text-xs font-medium text-gray-700 mb-2">
468
+ Legend:
469
+ </h4>
470
+ <div className="flex flex-wrap gap-3 text-xs">
471
+ <div className="flex items-center gap-1">
472
+ <div className="w-3 h-3 rounded-full bg-blue-500"></div>
473
+ <span>Selected</span>
474
+ </div>
475
+ <div className="flex items-center gap-1">
476
+ <div className="w-3 h-3 rounded-full bg-green-500"></div>
477
+ <span>High similarity (&gt;80%)</span>
478
+ </div>
479
+ <div className="flex items-center gap-1">
480
+ <div className="w-3 h-3 rounded-full bg-yellow-500"></div>
481
+ <span>Medium similarity (50-80%)</span>
482
+ </div>
483
+ <div className="flex items-center gap-1">
484
+ <div className="w-3 h-3 rounded-full bg-red-500"></div>
485
+ <span>Low similarity (&lt;50%)</span>
486
+ </div>
487
+ <div className="flex items-center gap-1">
488
+ <div className="w-3 h-3 rounded-full bg-gray-500"></div>
489
+ <span>Not compared</span>
490
+ </div>
491
+ </div>
492
+ </div>
493
+ )}
494
+ </div>
495
+ </div>
496
+ )}
497
+
498
+ {/* Similarity Results */}
499
+ <div className="flex-1 overflow-y-auto border border-gray-300 rounded-lg bg-white">
500
+ <div className="p-4">
501
+ <h3 className="text-sm font-medium text-gray-700 mb-3">
502
+ Cosine Similarities
503
+ {selectedExample &&
504
+ ` (vs "${selectedExample.text.substring(0, 30)}...")`}
505
+ </h3>
506
+ {!selectedExample ? (
507
+ <div className="text-gray-500 italic text-center py-8">
508
+ Select an example to see similarities
509
+ </div>
510
+ ) : similarities.length === 0 ? (
511
+ <div className="text-gray-500 italic text-center py-8">
512
+ No other examples with embeddings to compare
513
+ </div>
514
+ ) : (
515
+ <div className="space-y-2">
516
+ {similarities.map((sim) => {
517
+ const example = examples.find(
518
+ (ex) => ex.id === sim.exampleId
519
+ )
520
+ if (!example) return null
521
+
522
+ const similarityPercent = (sim.similarity * 100).toFixed(1)
523
+ const color =
524
+ sim.similarity > 0.8
525
+ ? 'text-green-600'
526
+ : sim.similarity > 0.5
527
+ ? 'text-yellow-600'
528
+ : 'text-red-500'
529
+
530
+ return (
531
+ <div
532
+ key={sim.exampleId}
533
+ className="p-3 border border-gray-200 rounded-lg hover:bg-gray-50 transition-colors"
534
+ >
535
+ <div className="flex justify-between items-start">
536
+ <div className="flex-1 min-w-0">
537
+ <div className="text-sm text-gray-800 break-words">
538
+ {example.text}
539
+ </div>
540
+ </div>
541
+ <div className={`ml-2 text-sm font-medium ${color}`}>
542
+ {similarityPercent}%
543
+ </div>
544
+ </div>
545
+ <div className="mt-2">
546
+ <div className="w-full bg-gray-200 rounded-full h-2">
547
+ <div
548
+ className={`h-2 rounded-full transition-all duration-300 ${
549
+ sim.similarity > 0.8
550
+ ? 'bg-green-500'
551
+ : sim.similarity > 0.5
552
+ ? 'bg-yellow-500'
553
+ : 'bg-red-500'
554
+ }`}
555
+ style={{
556
+ width: `${Math.max(sim.similarity * 100, 5)}%`
557
+ }}
558
+ />
559
+ </div>
560
+ </div>
561
+ </div>
562
+ )
563
+ })}
564
+ </div>
565
+ )}
566
+ </div>
567
+ </div>
568
+ </div>
569
+ </div>
570
+
571
+ {!hasBeenLoaded && (
572
+ <div className="text-center text-gray-500 text-sm mt-2">
573
+ Please load a feature extraction model first to start generating
574
+ embeddings
575
+ </div>
576
+ )}
577
+
578
+ {hasBeenLoaded && examples.length === 0 && (
579
+ <div className="text-center text-blue-600 text-sm mt-2">
580
+ 💡 Tip: Click "Load Samples" to try with example texts, or add your
581
+ own text above
582
+ </div>
583
+ )}
584
+ </div>
585
+ )
586
+ }
587
+
588
+ export default FeatureExtraction
src/components/FeatureExtractionConfig.tsx ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from 'react'
2
+ import { useFeatureExtraction } from '../contexts/FeatureExtractionContext'
3
+
4
+ const FeatureExtractionConfig = () => {
5
+ const { config, setConfig } = useFeatureExtraction()
6
+
7
+ return (
8
+ <div className="space-y-4">
9
+ <h3 className="text-lg font-semibold text-gray-900">
10
+ Feature Extraction Settings
11
+ </h3>
12
+
13
+ <div className="space-y-3">
14
+ <div>
15
+ <label className="block text-sm font-medium text-gray-700 mb-1">
16
+ Pooling Strategy
17
+ </label>
18
+ <select
19
+ value={config.pooling}
20
+ onChange={(e) => setConfig(prev => ({
21
+ ...prev,
22
+ pooling: e.target.value as 'mean' | 'cls' | 'max'
23
+ }))}
24
+ className="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-blue-500 text-sm"
25
+ >
26
+ <option value="mean">Mean Pooling</option>
27
+ <option value="cls">CLS Token</option>
28
+ <option value="max">Max Pooling</option>
29
+ </select>
30
+ <p className="text-xs text-gray-500 mt-1">
31
+ How to aggregate token embeddings into sentence embeddings
32
+ </p>
33
+ </div>
34
+
35
+ <div>
36
+ <label className="flex items-center space-x-2">
37
+ <input
38
+ type="checkbox"
39
+ checked={config.normalize}
40
+ onChange={(e) => setConfig(prev => ({
41
+ ...prev,
42
+ normalize: e.target.checked
43
+ }))}
44
+ className="rounded border-gray-300 text-blue-600 shadow-sm focus:border-blue-300 focus:ring focus:ring-blue-200 focus:ring-opacity-50"
45
+ />
46
+ <span className="text-sm font-medium text-gray-700">
47
+ Normalize Embeddings
48
+ </span>
49
+ </label>
50
+ <p className="text-xs text-gray-500 mt-1 ml-6">
51
+ L2 normalize embeddings for better similarity calculations
52
+ </p>
53
+ </div>
54
+ </div>
55
+
56
+ <div className="pt-2 border-t border-gray-200">
57
+ <div className="text-xs text-gray-500">
58
+ <p className="mb-1">
59
+ <strong>Mean Pooling:</strong> Average all token embeddings
60
+ </p>
61
+ <p className="mb-1">
62
+ <strong>CLS Token:</strong> Use the [CLS] token embedding (if available)
63
+ </p>
64
+ <p>
65
+ <strong>Max Pooling:</strong> Take element-wise maximum across tokens
66
+ </p>
67
+ </div>
68
+ </div>
69
+ </div>
70
+ )
71
+ }
72
+
73
+ export default FeatureExtractionConfig
src/components/PipelineLayout.tsx CHANGED
@@ -1,5 +1,6 @@
1
  import { useModel } from '../contexts/ModelContext'
2
  import { TextGenerationProvider } from '../contexts/TextGenerationContext'
 
3
 
4
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
5
  const { pipeline } = useModel()
@@ -8,6 +9,9 @@ export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
8
  case 'text-generation':
9
  return <TextGenerationProvider>{children}</TextGenerationProvider>
10
 
 
 
 
11
  // case 'zero-shot-classification':
12
  // return <ZeroShotProvider>{children}</ZeroShotProvider>;
13
 
 
1
  import { useModel } from '../contexts/ModelContext'
2
  import { TextGenerationProvider } from '../contexts/TextGenerationContext'
3
+ import { FeatureExtractionProvider } from '../contexts/FeatureExtractionContext'
4
 
5
  export const PipelineLayout = ({ children }: { children: React.ReactNode }) => {
6
  const { pipeline } = useModel()
 
9
  case 'text-generation':
10
  return <TextGenerationProvider>{children}</TextGenerationProvider>
11
 
12
+ case 'feature-extraction':
13
+ return <FeatureExtractionProvider>{children}</FeatureExtractionProvider>
14
+
15
  // case 'zero-shot-classification':
16
  // return <ZeroShotProvider>{children}</ZeroShotProvider>;
17
 
src/components/Sidebar.tsx CHANGED
@@ -4,6 +4,7 @@ import ModelSelector from './ModelSelector'
4
  import ModelInfo from './ModelInfo'
5
  import { useModel } from '../contexts/ModelContext'
6
  import TextGenerationConfig from './TextGenerationConfig'
 
7
 
8
  interface SidebarProps {
9
  isOpen: boolean
@@ -75,6 +76,7 @@ const Sidebar = ({ isOpen, onClose }: SidebarProps) => {
75
 
76
  <hr className="border-gray-200" />
77
  {pipeline === 'text-generation' && <TextGenerationConfig />}
 
78
  </div>
79
  </div>
80
  </div>
 
4
  import ModelInfo from './ModelInfo'
5
  import { useModel } from '../contexts/ModelContext'
6
  import TextGenerationConfig from './TextGenerationConfig'
7
+ import FeatureExtractionConfig from './FeatureExtractionConfig'
8
 
9
  interface SidebarProps {
10
  isOpen: boolean
 
76
 
77
  <hr className="border-gray-200" />
78
  {pipeline === 'text-generation' && <TextGenerationConfig />}
79
+ {pipeline === 'feature-extraction' && <FeatureExtractionConfig />}
80
  </div>
81
  </div>
82
  </div>
src/contexts/FeatureExtractionContext.tsx ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { createContext, useContext, useState, useCallback } from 'react'
2
+ import { EmbeddingExample, SimilarityResult } from '../types'
3
+
4
+ interface FeatureExtractionConfig {
5
+ pooling: 'mean' | 'cls' | 'max'
6
+ normalize: boolean
7
+ }
8
+
9
+ interface FeatureExtractionContextType {
10
+ examples: EmbeddingExample[]
11
+ setExamples: React.Dispatch<React.SetStateAction<EmbeddingExample[]>>
12
+ selectedExample: EmbeddingExample | null
13
+ setSelectedExample: React.Dispatch<React.SetStateAction<EmbeddingExample | null>>
14
+ similarities: SimilarityResult[]
15
+ setSimilarities: React.Dispatch<React.SetStateAction<SimilarityResult[]>>
16
+ config: FeatureExtractionConfig
17
+ setConfig: React.Dispatch<React.SetStateAction<FeatureExtractionConfig>>
18
+ addExample: (text: string) => void
19
+ removeExample: (id: string) => void
20
+ updateExample: (id: string, updates: Partial<EmbeddingExample>) => void
21
+ calculateSimilarities: (targetExample: EmbeddingExample) => void
22
+ clearExamples: () => void
23
+ }
24
+
25
+ const FeatureExtractionContext = createContext<FeatureExtractionContextType | undefined>(undefined)
26
+
27
+ export const useFeatureExtraction = () => {
28
+ const context = useContext(FeatureExtractionContext)
29
+ if (!context) {
30
+ throw new Error('useFeatureExtraction must be used within a FeatureExtractionProvider')
31
+ }
32
+ return context
33
+ }
34
+
35
+ // Cosine similarity calculation
36
+ const cosineSimilarity = (a: number[], b: number[]): number => {
37
+ if (a.length !== b.length) {
38
+ throw new Error('Vectors must have the same length')
39
+ }
40
+
41
+ let dotProduct = 0
42
+ let normA = 0
43
+ let normB = 0
44
+
45
+ for (let i = 0; i < a.length; i++) {
46
+ dotProduct += a[i] * b[i]
47
+ normA += a[i] * a[i]
48
+ normB += b[i] * b[i]
49
+ }
50
+
51
+ normA = Math.sqrt(normA)
52
+ normB = Math.sqrt(normB)
53
+
54
+ if (normA === 0 || normB === 0) {
55
+ return 0
56
+ }
57
+
58
+ return dotProduct / (normA * normB)
59
+ }
60
+
61
+ export const FeatureExtractionProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => {
62
+ const [examples, setExamples] = useState<EmbeddingExample[]>([])
63
+ const [selectedExample, setSelectedExample] = useState<EmbeddingExample | null>(null)
64
+ const [similarities, setSimilarities] = useState<SimilarityResult[]>([])
65
+ const [config, setConfig] = useState<FeatureExtractionConfig>({
66
+ pooling: 'mean',
67
+ normalize: true
68
+ })
69
+
70
+ const addExample = useCallback((text: string) => {
71
+ const newExample: EmbeddingExample = {
72
+ id: Date.now().toString() + Math.random().toString(36).substr(2, 9),
73
+ text: text.trim(),
74
+ embedding: undefined,
75
+ isLoading: false
76
+ }
77
+ setExamples(prev => [...prev, newExample])
78
+ }, [])
79
+
80
+ const removeExample = useCallback((id: string) => {
81
+ setExamples(prev => prev.filter(example => example.id !== id))
82
+ if (selectedExample?.id === id) {
83
+ setSelectedExample(null)
84
+ setSimilarities([])
85
+ }
86
+ }, [selectedExample])
87
+
88
+ const updateExample = useCallback((id: string, updates: Partial<EmbeddingExample>) => {
89
+ setExamples(prev => prev.map(example =>
90
+ example.id === id ? { ...example, ...updates } : example
91
+ ))
92
+ }, [])
93
+
94
+ const calculateSimilarities = useCallback((targetExample: EmbeddingExample) => {
95
+ if (!targetExample.embedding) {
96
+ setSimilarities([])
97
+ return
98
+ }
99
+
100
+ const newSimilarities: SimilarityResult[] = examples
101
+ .filter(example => example.id !== targetExample.id && example.embedding)
102
+ .map(example => ({
103
+ exampleId: example.id,
104
+ similarity: cosineSimilarity(targetExample.embedding!, example.embedding!)
105
+ }))
106
+ .sort((a, b) => b.similarity - a.similarity)
107
+
108
+ setSimilarities(newSimilarities)
109
+ }, [examples])
110
+
111
+ const clearExamples = useCallback(() => {
112
+ setExamples([])
113
+ setSelectedExample(null)
114
+ setSimilarities([])
115
+ }, [])
116
+
117
+ const value: FeatureExtractionContextType = {
118
+ examples,
119
+ setExamples,
120
+ selectedExample,
121
+ setSelectedExample,
122
+ similarities,
123
+ setSimilarities,
124
+ config,
125
+ setConfig,
126
+ addExample,
127
+ removeExample,
128
+ updateExample,
129
+ calculateSimilarities,
130
+ clearExamples
131
+ }
132
+
133
+ return (
134
+ <FeatureExtractionContext.Provider value={value}>
135
+ {children}
136
+ </FeatureExtractionContext.Provider>
137
+ )
138
+ }
src/lib/workerManager.ts CHANGED
@@ -14,6 +14,9 @@ export const getWorker = (pipeline: string) => {
14
  case 'text-generation':
15
  workerUrl = `${process.env.PUBLIC_URL}/workers/text-generation.js`
16
  break
 
 
 
17
  default:
18
  return null
19
  }
 
14
  case 'text-generation':
15
  workerUrl = `${process.env.PUBLIC_URL}/workers/text-generation.js`
16
  break
17
+ case 'feature-extraction':
18
+ workerUrl = `${process.env.PUBLIC_URL}/workers/feature-extraction.js`
19
+ break
20
  default:
21
  return null
22
  }
src/types.ts CHANGED
@@ -24,6 +24,7 @@ export type WorkerStatus =
24
  | 'ready'
25
  | 'output'
26
  | 'loading'
 
27
  | 'error'
28
  | 'disposed'
29
 
@@ -63,6 +64,29 @@ export interface TextGenerationWorkerInput {
63
  dtype: QuantizationType
64
  }
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const
67
  const q4Types = ['q4', 'bnb4', 'q4f16'] as const
68
  const fp16Types = ['fp16'] as const
 
24
  | 'ready'
25
  | 'output'
26
  | 'loading'
27
+ | 'progress'
28
  | 'error'
29
  | 'disposed'
30
 
 
64
  dtype: QuantizationType
65
  }
66
 
67
+ export interface FeatureExtractionWorkerInput {
68
+ type: 'extract' | 'load'
69
+ texts?: string[]
70
+ model: string
71
+ dtype: QuantizationType
72
+ config: {
73
+ pooling: 'mean' | 'cls' | 'max'
74
+ normalize: boolean
75
+ }
76
+ }
77
+
78
+ export interface EmbeddingExample {
79
+ id: string
80
+ text: string
81
+ embedding?: number[]
82
+ isLoading?: boolean
83
+ }
84
+
85
+ export interface SimilarityResult {
86
+ exampleId: string
87
+ similarity: number
88
+ }
89
+
90
  const q8Types = ['q8', 'int8', 'bnb8', 'uint8'] as const
91
  const q4Types = ['q4', 'bnb4', 'q4f16'] as const
92
  const fp16Types = ['fp16'] as const