Vokturz committed on
Commit
8d9b8a5
·
1 Parent(s): 79eafc9

Add support for Style TTS2 models in code examples

Browse files
src/components/ModelCode.tsx CHANGED
@@ -100,16 +100,31 @@ const ModelCode = ({ isCodeModalOpen, setIsCodeModalOpen }: ModelCodeProps) => {
100
  top_k: 5
101
  }
102
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
 
105
- const jsCode = `import { pipeline } from '@huggingface/transformers';
106
 
107
  const ${classType} = pipeline('${pipeline}', '${modelInfo.name}', {
108
  dtype: '${selectedQuantization}',
109
  device: 'webgpu' // 'wasm'
110
  });
111
  const result = await ${classType}(${modelInfo.hasChatTemplate ? exampleData : "'" + exampleData + "'"}, ${JSON.stringify(config, null, 2)});
112
- console.log(result);
113
  `
114
 
115
  const configPython = Object.entries(config)
@@ -119,12 +134,34 @@ console.log(result);
119
  )
120
  .join(', ')
121
 
122
- const pythonCode = `from transformers import pipeline
123
 
124
  ${classType} = pipeline("${pipeline}", model="${modelInfo.name}")
125
  result = ${classType}(${modelInfo.hasChatTemplate ? exampleData : '"' + exampleData + '"'}, ${configPython})
126
- print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  `
 
128
 
129
  const copyToClipboard = (text: string) => {
130
  navigator.clipboard.writeText(text)
@@ -132,6 +169,7 @@ print(result)
132
  setTimeout(() => setIsCopied(false), 2000)
133
  }
134
  const pipelineName = pipeline
 
135
  .split('-')
136
  .map((word, index) => word.charAt(0).toUpperCase() + word.slice(1))
137
  .join('')
@@ -144,8 +182,19 @@ print(result)
144
  title={title}
145
  maxWidth="5xl"
146
  >
147
- {/* ... (all your modal content JSX is unchanged) */}
148
  <div className="text-sm max-w-none px-4">
 
 
 
 
 
 
 
 
 
 
 
 
149
  <div className="flex flex-row">
150
  <img src="/javascript-logo.svg" className="w-6 h-6 mr-1 rounded" />
151
  <h2 className="text-lg font-medium mb-2">Javascript</h2>
@@ -153,7 +202,7 @@ print(result)
153
  <div className="flex flex-row items-center text-sm hover:underline text-foreground/60">
154
  <Link className="h-3 w-3 mr-2" />
155
  <a
156
- href={`https://huggingface.co/docs/transformers.js/api/pipelines#pipelines${pipeline.replace(/-/g, '')}pipeline`}
157
  target="_blank"
158
  rel="noopener noreferrer"
159
  >
 
100
  top_k: 5
101
  }
102
  break
103
+ case 'text-to-speech':
104
+ classType = 'synthesizer'
105
+ exampleData =
106
+ "Life is like a box of chocolates. You never know what you're gonna get."
107
+ if (modelInfo.isStyleTTS2) {
108
+ config = {
109
+ voice: 'af_heart'
110
+ }
111
+ } else {
112
+ config = {
113
+ speaker_embeddings:
114
+ 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'
115
+ }
116
+ }
117
+ break
118
  }
119
 
120
+ let jsCode = `import { pipeline } from '@huggingface/transformers';
121
 
122
  const ${classType} = pipeline('${pipeline}', '${modelInfo.name}', {
123
  dtype: '${selectedQuantization}',
124
  device: 'webgpu' // 'wasm'
125
  });
126
  const result = await ${classType}(${modelInfo.hasChatTemplate ? exampleData : "'" + exampleData + "'"}, ${JSON.stringify(config, null, 2)});
127
+ ${pipeline === 'text-to-speech' ? "result.save('audio.wav')" : 'console.log(result);'}
128
  `
129
 
130
  const configPython = Object.entries(config)
 
134
  )
135
  .join(', ')
136
 
137
+ let pythonCode = `from transformers import pipeline
138
 
139
  ${classType} = pipeline("${pipeline}", model="${modelInfo.name}")
140
  result = ${classType}(${modelInfo.hasChatTemplate ? exampleData : '"' + exampleData + '"'}, ${configPython})
141
+ ${pipeline === 'text-to-speech' ? 'audio = result["audio"]' : 'print(result)'}
142
+ `
143
+
144
+ if (modelInfo.isStyleTTS2) {
145
+ jsCode = `
146
+ import { KokoroTTS } from "kokoro-js";
147
+ const tts = await KokoroTTS.from_pretrained('${modelInfo.name}', {
148
+ dtype: '${selectedQuantization}',
149
+ device: 'webgpu' // 'wasm'
150
+ });
151
+
152
+ const audio = await tts.generate("${exampleData}", ${JSON.stringify(config, null, 2)});
153
+ audio.save("audio.wav");
154
+ `
155
+
156
+ pythonCode = `!pip install -q kokoro>=0.9.4 soundfile
157
+ from kokoro import KPipeline
158
+
159
+ pipeline = KPipeline(lang_code='a')
160
+ generator = pipeline("${exampleData}", voice='af_heart')
161
+ for i, (gs, ps, audio) in enumerate(generator):
162
+ print(i, gs, ps)
163
  `
164
+ }
165
 
166
  const copyToClipboard = (text: string) => {
167
  navigator.clipboard.writeText(text)
 
169
  setTimeout(() => setIsCopied(false), 2000)
170
  }
171
  const pipelineName = pipeline
172
+ .replace('speech', 'audio')
173
  .split('-')
174
  .map((word, index) => word.charAt(0).toUpperCase() + word.slice(1))
175
  .join('')
 
182
  title={title}
183
  maxWidth="5xl"
184
  >
 
185
  <div className="text-sm max-w-none px-4">
186
+ {modelInfo.isStyleTTS2 && (
187
+ <div className="flex flex-row items-center text-sm hover:underline text-foreground/60 mb-4">
188
+ <a
189
+ href={`https://github.com/hexgrad/kokoro`}
190
+ target="_blank"
191
+ rel="noopener noreferrer"
192
+ >
193
+ Check Kokoro github for more info about Style TTS2 models
194
+ </a>
195
+ </div>
196
+ )}
197
+
198
  <div className="flex flex-row">
199
  <img src="/javascript-logo.svg" className="w-6 h-6 mr-1 rounded" />
200
  <h2 className="text-lg font-medium mb-2">Javascript</h2>
 
202
  <div className="flex flex-row items-center text-sm hover:underline text-foreground/60">
203
  <Link className="h-3 w-3 mr-2" />
204
  <a
205
+ href={`https://huggingface.co/docs/transformers.js/api/pipelines#pipelines${pipeline.replace(/-/g, '').replace('speech', 'audio')}pipeline`}
206
  target="_blank"
207
  rel="noopener noreferrer"
208
  >
src/components/pipelines/TextToSpeechConfig.tsx CHANGED
@@ -30,7 +30,7 @@ const TextToSpeechConfig: React.FC<TextToSpeechConfigProps> = ({
30
  Select Voice
31
  </Label>
32
  <Select
33
- value={config.voice}
34
  onValueChange={(value) =>
35
  setConfig((prev) => ({
36
  ...prev,
 
30
  Select Voice
31
  </Label>
32
  <Select
33
+ value={config.voice || ''}
34
  onValueChange={(value) =>
35
  setConfig((prev) => ({
36
  ...prev,