Vokturz commited on
Commit
ca67cfa
·
1 Parent(s): de6e73e

Refactor model configuration to use config object

Browse files
public/workers/image-classification.js CHANGED
@@ -62,7 +62,7 @@ class MyImageClassificationPipeline {
62
  // Listen for messages from the main thread
63
  self.addEventListener('message', async (event) => {
64
  try {
65
- const { type, image, model, dtype, topK = 1 } = event.data
66
 
67
  if (!model) {
68
  self.postMessage({
@@ -100,9 +100,7 @@ self.addEventListener('message', async (event) => {
100
 
101
  try {
102
  // Run classification
103
- const output = await classifier(image, {
104
- top_k: topK
105
- })
106
 
107
  // Format predictions
108
  const predictions = output.map((item) => ({
 
62
  // Listen for messages from the main thread
63
  self.addEventListener('message', async (event) => {
64
  try {
65
+ const { type, image, model, dtype, config } = event.data
66
 
67
  if (!model) {
68
  self.postMessage({
 
100
 
101
  try {
102
  // Run classification
103
+ const output = await classifier(image, config)
 
 
104
 
105
  // Format predictions
106
  const predictions = output.map((item) => ({
public/workers/text-generation.js CHANGED
@@ -49,20 +49,8 @@ class MyTextGenerationPipeline {
49
  // Listen for messages from the main thread
50
  self.addEventListener('message', async (event) => {
51
  try {
52
- const {
53
- type,
54
- model,
55
- dtype,
56
- messages,
57
- prompt,
58
- hasChatTemplate,
59
- temperature,
60
- max_new_tokens,
61
- top_p,
62
- top_k,
63
- do_sample,
64
- stop_words
65
- } = event.data
66
 
67
  if (type === 'stop') {
68
  MyTextGenerationPipeline.stopGeneration()
@@ -108,12 +96,11 @@ self.addEventListener('message', async (event) => {
108
  }
109
 
110
  const options = {
111
- max_new_tokens: max_new_tokens || 100,
112
- temperature: temperature || 0.7,
113
- do_sample: do_sample !== false,
114
- ...(top_p && { top_p }),
115
- ...(top_k && { top_k }),
116
- ...(stop_words && stop_words.length > 0 && { stop_words })
117
  }
118
 
119
  // Create an AbortController for this generation
 
49
  // Listen for messages from the main thread
50
  self.addEventListener('message', async (event) => {
51
  try {
52
+ const { type, model, dtype, messages, prompt, hasChatTemplate, config } =
53
+ event.data
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  if (type === 'stop') {
56
  MyTextGenerationPipeline.stopGeneration()
 
96
  }
97
 
98
  const options = {
99
+ max_new_tokens: config.max_new_tokens || 100,
100
+ temperature: config.temperature || 0.7,
101
+ do_sample: config.do_sample !== false,
102
+ ...(config.top_p && { top_p }),
103
+ ...(config.top_k && { top_k })
 
104
  }
105
 
106
  // Create an AbortController for this generation
src/components/pipelines/ImageClassification.tsx CHANGED
@@ -45,7 +45,7 @@ function ImageClassification() {
45
  removeExample,
46
  updateExample,
47
  clearExamples,
48
- topK
49
  } = useImageClassification()
50
 
51
  const [isClassifying, setIsClassifying] = useState<boolean>(false)
@@ -75,12 +75,12 @@ function ImageClassification() {
75
  image: example.url,
76
  model: modelInfo.id,
77
  dtype: selectedQuantization ?? 'fp32',
78
- topK
79
  }
80
 
81
  activeWorker.postMessage(message)
82
  },
83
- [modelInfo, activeWorker, selectedQuantization, topK, updateExample]
84
  )
85
 
86
  const handleFileSelect = useCallback(
 
45
  removeExample,
46
  updateExample,
47
  clearExamples,
48
+ config
49
  } = useImageClassification()
50
 
51
  const [isClassifying, setIsClassifying] = useState<boolean>(false)
 
75
  image: example.url,
76
  model: modelInfo.id,
77
  dtype: selectedQuantization ?? 'fp32',
78
+ config
79
  }
80
 
81
  activeWorker.postMessage(message)
82
  },
83
+ [modelInfo, activeWorker, selectedQuantization, config, updateExample]
84
  )
85
 
86
  const handleFileSelect = useCallback(
src/components/pipelines/ImageClassificationConfig.tsx CHANGED
@@ -3,7 +3,7 @@ import { useImageClassification } from '../../contexts/ImageClassificationContex
3
  import { Slider } from '../ui/slider'
4
 
5
  const ImageClassificationConfig = () => {
6
- const { topK, setTopK } = useImageClassification()
7
 
8
  return (
9
  <div className="space-y-4">
@@ -14,14 +14,14 @@ const ImageClassificationConfig = () => {
14
  <div className="space-y-3">
15
  <div>
16
  <label className="block text-sm font-medium text-foreground/80 mb-1">
17
- Top K Predictions: {topK}
18
  </label>
19
  <Slider
20
- defaultValue={[topK]}
21
  min={1}
22
  max={10}
23
  step={1}
24
- onValueChange={(value) => setTopK(value[0])}
25
  className="w-full rounded-lg"
26
  />
27
  <div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
 
3
  import { Slider } from '../ui/slider'
4
 
5
  const ImageClassificationConfig = () => {
6
+ const { config, setConfig } = useImageClassification()
7
 
8
  return (
9
  <div className="space-y-4">
 
14
  <div className="space-y-3">
15
  <div>
16
  <label className="block text-sm font-medium text-foreground/80 mb-1">
17
+ Top K Predictions: {config.top_k}
18
  </label>
19
  <Slider
20
+ defaultValue={[config.top_k]}
21
  min={1}
22
  max={10}
23
  step={1}
24
+ onValueChange={(value) => setConfig({ top_k: value[0] })}
25
  className="w-full rounded-lg"
26
  />
27
  <div className="flex justify-between text-xs text-muted-foreground/60 mt-1">
src/components/pipelines/TextGeneration.tsx CHANGED
@@ -58,12 +58,8 @@ function TextGeneration() {
58
  messages: updatedMessages,
59
  hasChatTemplate: modelInfo.hasChatTemplate,
60
  model: modelInfo.id,
61
- temperature: config.temperature,
62
- max_new_tokens: config.maxTokens,
63
- top_p: config.topP,
64
- top_k: config.topK,
65
- do_sample: config.doSample,
66
- dtype: selectedQuantization ?? 'fp32'
67
  }
68
 
69
  activeWorker.postMessage(message)
@@ -87,11 +83,7 @@ function TextGeneration() {
87
  prompt: prompt.trim(),
88
  hasChatTemplate: modelInfo.hasChatTemplate,
89
  model: modelInfo.id,
90
- temperature: config.temperature,
91
- max_new_tokens: config.maxTokens,
92
- top_p: config.topP,
93
- top_k: config.topK,
94
- do_sample: config.doSample,
95
  dtype: selectedQuantization ?? 'fp32'
96
  }
97
 
 
58
  messages: updatedMessages,
59
  hasChatTemplate: modelInfo.hasChatTemplate,
60
  model: modelInfo.id,
61
+ dtype: selectedQuantization ?? 'fp32',
62
+ config
 
 
 
 
63
  }
64
 
65
  activeWorker.postMessage(message)
 
83
  prompt: prompt.trim(),
84
  hasChatTemplate: modelInfo.hasChatTemplate,
85
  model: modelInfo.id,
86
+ config,
 
 
 
 
87
  dtype: selectedQuantization ?? 'fp32'
88
  }
89
 
src/contexts/ImageClassificationContext.tsx CHANGED
@@ -1,6 +1,10 @@
1
  import React, { createContext, useContext, useState, useCallback } from 'react'
2
  import { ImageExample } from '../types'
3
 
 
 
 
 
4
  interface ImageClassificationContextType {
5
  examples: ImageExample[]
6
  selectedExample: ImageExample | null
@@ -9,8 +13,8 @@ interface ImageClassificationContextType {
9
  removeExample: (id: string) => void
10
  updateExample: (id: string, updates: Partial<ImageExample>) => void
11
  clearExamples: () => void
12
- topK: number
13
- setTopK: (k: number) => void
14
  }
15
 
16
  const ImageClassificationContext = createContext<
@@ -38,7 +42,9 @@ export function ImageClassificationProvider({
38
  const [selectedExample, setSelectedExample] = useState<ImageExample | null>(
39
  null
40
  )
41
- const [topK, setTopK] = useState<number>(5)
 
 
42
 
43
  const addExample = useCallback((file: File) => {
44
  const id = Math.random().toString(36).substr(2, 9)
@@ -105,8 +111,8 @@ export function ImageClassificationProvider({
105
  removeExample,
106
  updateExample,
107
  clearExamples,
108
- topK,
109
- setTopK
110
  }
111
 
112
  return (
 
1
  import React, { createContext, useContext, useState, useCallback } from 'react'
2
  import { ImageExample } from '../types'
3
 
4
+ interface ImageClassificationConfig {
5
+ top_k: number
6
+ }
7
+
8
  interface ImageClassificationContextType {
9
  examples: ImageExample[]
10
  selectedExample: ImageExample | null
 
13
  removeExample: (id: string) => void
14
  updateExample: (id: string, updates: Partial<ImageExample>) => void
15
  clearExamples: () => void
16
+ config: ImageClassificationConfig
17
+ setConfig: React.Dispatch<React.SetStateAction<ImageClassificationConfig>>
18
  }
19
 
20
  const ImageClassificationContext = createContext<
 
42
  const [selectedExample, setSelectedExample] = useState<ImageExample | null>(
43
  null
44
  )
45
+ const [config, setConfig] = useState<ImageClassificationConfig>({
46
+ top_k: 5
47
+ })
48
 
49
  const addExample = useCallback((file: File) => {
50
  const id = Math.random().toString(36).substr(2, 9)
 
111
  removeExample,
112
  updateExample,
113
  clearExamples,
114
+ config,
115
+ setConfig
116
  }
117
 
118
  return (
src/types.ts CHANGED
@@ -56,11 +56,13 @@ export interface TextGenerationWorkerInput {
56
  messages?: ChatMessage[]
57
  hasChatTemplate: boolean
58
  model: string
59
- temperature?: number
60
- max_new_tokens?: number
61
- top_p?: number
62
- top_k?: number
63
- do_sample?: boolean
 
 
64
  dtype: QuantizationType
65
  }
66
 
@@ -80,7 +82,9 @@ export interface ImageClassificationWorkerInput {
80
  image: string | ImageData | HTMLImageElement | HTMLCanvasElement
81
  model: string
82
  dtype: QuantizationType
83
- topK?: number
 
 
84
  }
85
 
86
  export interface ImageClassificationResult {
 
56
  messages?: ChatMessage[]
57
  hasChatTemplate: boolean
58
  model: string
59
+ config?: {
60
+ temperature?: number
61
+ max_new_tokens?: number
62
+ top_p?: number
63
+ top_k?: number
64
+ do_sample?: boolean
65
+ }
66
  dtype: QuantizationType
67
  }
68
 
 
82
  image: string | ImageData | HTMLImageElement | HTMLCanvasElement
83
  model: string
84
  dtype: QuantizationType
85
+ config: {
86
+ top_k?: number
87
+ }
88
  }
89
 
90
  export interface ImageClassificationResult {