Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
machineuser
committed on
Commit
•
94753b6
1
Parent(s):
ec4dcd5
Sync widgets demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- packages/inference/.eslintignore +3 -0
- packages/inference/.prettierignore +6 -0
- packages/inference/LICENSE +21 -0
- packages/inference/README.md +519 -0
- packages/inference/package.json +59 -0
- packages/inference/pnpm-lock.yaml +19 -0
- packages/inference/scripts/generate-dts.ts +179 -0
- packages/inference/src/HfInference.ts +67 -0
- packages/inference/src/index.ts +4 -0
- packages/inference/src/lib/InferenceOutputError.ts +8 -0
- packages/inference/src/lib/getDefaultTask.ts +61 -0
- packages/inference/src/lib/isUrl.ts +3 -0
- packages/inference/src/lib/makeRequestOptions.ts +113 -0
- packages/inference/src/tasks/audio/audioClassification.ts +44 -0
- packages/inference/src/tasks/audio/audioToAudio.ts +49 -0
- packages/inference/src/tasks/audio/automaticSpeechRecognition.ts +36 -0
- packages/inference/src/tasks/audio/textToSpeech.ts +28 -0
- packages/inference/src/tasks/custom/request.ts +41 -0
- packages/inference/src/tasks/custom/streamingRequest.ts +82 -0
- packages/inference/src/tasks/cv/imageClassification.ts +43 -0
- packages/inference/src/tasks/cv/imageSegmentation.ts +48 -0
- packages/inference/src/tasks/cv/imageToImage.ts +86 -0
- packages/inference/src/tasks/cv/imageToText.ts +35 -0
- packages/inference/src/tasks/cv/objectDetection.ts +61 -0
- packages/inference/src/tasks/cv/textToImage.ts +51 -0
- packages/inference/src/tasks/cv/zeroShotImageClassification.ts +58 -0
- packages/inference/src/tasks/index.ts +40 -0
- packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts +73 -0
- packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts +59 -0
- packages/inference/src/tasks/nlp/featureExtraction.ts +52 -0
- packages/inference/src/tasks/nlp/fillMask.ts +51 -0
- packages/inference/src/tasks/nlp/questionAnswering.ts +53 -0
- packages/inference/src/tasks/nlp/sentenceSimilarity.ts +40 -0
- packages/inference/src/tasks/nlp/summarization.ts +62 -0
- packages/inference/src/tasks/nlp/tableQuestionAnswering.ts +61 -0
- packages/inference/src/tasks/nlp/textClassification.ts +42 -0
- packages/inference/src/tasks/nlp/textGeneration.ts +22 -0
- packages/inference/src/tasks/nlp/textGenerationStream.ts +96 -0
- packages/inference/src/tasks/nlp/tokenClassification.ts +83 -0
- packages/inference/src/tasks/nlp/translation.ts +34 -0
- packages/inference/src/tasks/nlp/zeroShotClassification.ts +58 -0
- packages/inference/src/tasks/tabular/tabularClassification.ts +37 -0
- packages/inference/src/tasks/tabular/tabularRegression.ts +37 -0
- packages/inference/src/types.ts +61 -0
- packages/inference/src/utils/distributive-omit.d.ts +15 -0
- packages/inference/src/utils/omit.ts +11 -0
- packages/inference/src/utils/pick.ts +13 -0
- packages/inference/src/utils/toArray.ts +6 -0
- packages/inference/src/utils/typedInclude.ts +3 -0
- packages/inference/src/vendor/fetch-event-source/parse.spec.ts +389 -0
packages/inference/.eslintignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
dist
|
2 |
+
tapes.json
|
3 |
+
src/vendor
|
packages/inference/.prettierignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pnpm-lock.yaml
|
2 |
+
# Avoids tabs in the code samples, which don't display well on npm
|
3 |
+
README.md
|
4 |
+
dist
|
5 |
+
test/tapes.json
|
6 |
+
src/vendor
|
packages/inference/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Tim Mikeladze
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
packages/inference/README.md
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🤗 Hugging Face Inference Endpoints
|
2 |
+
|
3 |
+
A TypeScript-powered wrapper for the Hugging Face Inference Endpoints API. Learn more about Inference Endpoints at [Hugging Face](https://huggingface.co/inference-endpoints).
|
4 |
+
It works with both [Inference API (serverless)](https://huggingface.co/docs/api-inference/index) and [Inference Endpoints (dedicated)](https://huggingface.co/docs/inference-endpoints/index).
|
5 |
+
|
6 |
+
Check out the [full documentation](https://huggingface.co/docs/huggingface.js/inference/README).
|
7 |
+
|
8 |
+
You can also try out a live [interactive notebook](https://observablehq.com/@huggingface/hello-huggingface-js-inference), see some demos on [hf.co/huggingfacejs](https://huggingface.co/huggingfacejs), or watch a [Scrimba tutorial that explains how Inference Endpoints works](https://scrimba.com/scrim/cod8248f5adfd6e129582c523).
|
9 |
+
|
10 |
+
## Getting Started
|
11 |
+
|
12 |
+
### Install
|
13 |
+
|
14 |
+
#### Node
|
15 |
+
|
16 |
+
```console
|
17 |
+
npm install @huggingface/inference
|
18 |
+
|
19 |
+
pnpm add @huggingface/inference
|
20 |
+
|
21 |
+
yarn add @huggingface/inference
|
22 |
+
```
|
23 |
+
|
24 |
+
#### Deno
|
25 |
+
|
26 |
+
```ts
|
27 |
+
// esm.sh
|
28 |
+
import { HfInference } from "https://esm.sh/@huggingface/inference"
|
29 |
+
// or npm:
|
30 |
+
import { HfInference } from "npm:@huggingface/inference"
|
31 |
+
```
|
32 |
+
|
33 |
+
|
34 |
+
### Initialize
|
35 |
+
|
36 |
+
```typescript
|
37 |
+
import { HfInference } from '@huggingface/inference'
|
38 |
+
|
39 |
+
const hf = new HfInference('your access token')
|
40 |
+
```
|
41 |
+
|
42 |
+
❗**Important note:** Using an access token is optional to get started, however you will be rate limited eventually. Join [Hugging Face](https://huggingface.co/join) and then visit [access tokens](https://huggingface.co/settings/tokens) to generate your access token for **free**.
|
43 |
+
|
44 |
+
Your access token should be kept private. If you need to protect it in front-end applications, we suggest setting up a proxy server that stores the access token.
|
45 |
+
|
46 |
+
|
47 |
+
#### Tree-shaking
|
48 |
+
|
49 |
+
You can import the functions you need directly from the module instead of using the `HfInference` class.
|
50 |
+
|
51 |
+
```ts
|
52 |
+
import { textGeneration } from "@huggingface/inference";
|
53 |
+
|
54 |
+
await textGeneration({
|
55 |
+
accessToken: "hf_...",
|
56 |
+
model: "model_or_endpoint",
|
57 |
+
inputs: ...,
|
58 |
+
parameters: ...
|
59 |
+
})
|
60 |
+
```
|
61 |
+
|
62 |
+
This will enable tree-shaking by your bundler.
|
63 |
+
|
64 |
+
## Natural Language Processing
|
65 |
+
|
66 |
+
### Fill Mask
|
67 |
+
|
68 |
+
Tries to fill in a hole with a missing word (token to be precise).
|
69 |
+
|
70 |
+
```typescript
|
71 |
+
await hf.fillMask({
|
72 |
+
model: 'bert-base-uncased',
|
73 |
+
inputs: '[MASK] world!'
|
74 |
+
})
|
75 |
+
```
|
76 |
+
|
77 |
+
### Summarization
|
78 |
+
|
79 |
+
Summarizes longer text into shorter text. Be careful, some models have a maximum length of input.
|
80 |
+
|
81 |
+
```typescript
|
82 |
+
await hf.summarization({
|
83 |
+
model: 'facebook/bart-large-cnn',
|
84 |
+
inputs:
|
85 |
+
'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.',
|
86 |
+
parameters: {
|
87 |
+
max_length: 100
|
88 |
+
}
|
89 |
+
})
|
90 |
+
```
|
91 |
+
|
92 |
+
### Question Answering
|
93 |
+
|
94 |
+
Answers questions based on the context you provide.
|
95 |
+
|
96 |
+
```typescript
|
97 |
+
await hf.questionAnswering({
|
98 |
+
model: 'deepset/roberta-base-squad2',
|
99 |
+
inputs: {
|
100 |
+
question: 'What is the capital of France?',
|
101 |
+
context: 'The capital of France is Paris.'
|
102 |
+
}
|
103 |
+
})
|
104 |
+
```
|
105 |
+
|
106 |
+
### Table Question Answering
|
107 |
+
|
108 |
+
```typescript
|
109 |
+
await hf.tableQuestionAnswering({
|
110 |
+
model: 'google/tapas-base-finetuned-wtq',
|
111 |
+
inputs: {
|
112 |
+
query: 'How many stars does the transformers repository have?',
|
113 |
+
table: {
|
114 |
+
Repository: ['Transformers', 'Datasets', 'Tokenizers'],
|
115 |
+
Stars: ['36542', '4512', '3934'],
|
116 |
+
Contributors: ['651', '77', '34'],
|
117 |
+
'Programming language': ['Python', 'Python', 'Rust, Python and NodeJS']
|
118 |
+
}
|
119 |
+
}
|
120 |
+
})
|
121 |
+
```
|
122 |
+
|
123 |
+
### Text Classification
|
124 |
+
|
125 |
+
Often used for sentiment analysis, this method will assign labels to the given text along with a probability score of that label.
|
126 |
+
|
127 |
+
```typescript
|
128 |
+
await hf.textClassification({
|
129 |
+
model: 'distilbert-base-uncased-finetuned-sst-2-english',
|
130 |
+
inputs: 'I like you. I love you.'
|
131 |
+
})
|
132 |
+
```
|
133 |
+
|
134 |
+
### Text Generation
|
135 |
+
|
136 |
+
Generates text from an input prompt.
|
137 |
+
|
138 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/streaming-text-generation)
|
139 |
+
|
140 |
+
```typescript
|
141 |
+
await hf.textGeneration({
|
142 |
+
model: 'gpt2',
|
143 |
+
inputs: 'The answer to the universe is'
|
144 |
+
})
|
145 |
+
|
146 |
+
for await (const output of hf.textGenerationStream({
|
147 |
+
model: "google/flan-t5-xxl",
|
148 |
+
inputs: 'repeat "one two three four"',
|
149 |
+
parameters: { max_new_tokens: 250 }
|
150 |
+
})) {
|
151 |
+
console.log(output.token.text, output.generated_text);
|
152 |
+
}
|
153 |
+
```
|
154 |
+
|
155 |
+
### Token Classification
|
156 |
+
|
157 |
+
Used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text.
|
158 |
+
|
159 |
+
```typescript
|
160 |
+
await hf.tokenClassification({
|
161 |
+
model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
|
162 |
+
inputs: 'My name is Sarah Jessica Parker but you can call me Jessica'
|
163 |
+
})
|
164 |
+
```
|
165 |
+
|
166 |
+
### Translation
|
167 |
+
|
168 |
+
Converts text from one language to another.
|
169 |
+
|
170 |
+
```typescript
|
171 |
+
await hf.translation({
|
172 |
+
model: 't5-base',
|
173 |
+
inputs: 'My name is Wolfgang and I live in Berlin'
|
174 |
+
})
|
175 |
+
|
176 |
+
await hf.translation({
|
177 |
+
model: 'facebook/mbart-large-50-many-to-many-mmt',
|
178 |
+
inputs: textToTranslate,
|
179 |
+
parameters: {
|
180 |
+
"src_lang": "en_XX",
|
181 |
+
"tgt_lang": "fr_XX"
|
182 |
+
}
|
183 |
+
})
|
184 |
+
```
|
185 |
+
|
186 |
+
### Zero-Shot Classification
|
187 |
+
|
188 |
+
Checks how well an input text fits into a set of labels you provide.
|
189 |
+
|
190 |
+
```typescript
|
191 |
+
await hf.zeroShotClassification({
|
192 |
+
model: 'facebook/bart-large-mnli',
|
193 |
+
inputs: [
|
194 |
+
'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!'
|
195 |
+
],
|
196 |
+
parameters: { candidate_labels: ['refund', 'legal', 'faq'] }
|
197 |
+
})
|
198 |
+
```
|
199 |
+
|
200 |
+
### Conversational
|
201 |
+
|
202 |
+
This task corresponds to any chatbot-like structure. Models tend to have a shorter max_length, so check carefully whether the model you choose can handle the long-range dependencies you need.
|
203 |
+
|
204 |
+
```typescript
|
205 |
+
await hf.conversational({
|
206 |
+
model: 'microsoft/DialoGPT-large',
|
207 |
+
inputs: {
|
208 |
+
past_user_inputs: ['Which movie is the best ?'],
|
209 |
+
generated_responses: ['It is Die Hard for sure.'],
|
210 |
+
text: 'Can you explain why ?'
|
211 |
+
}
|
212 |
+
})
|
213 |
+
```
|
214 |
+
|
215 |
+
### Sentence Similarity
|
216 |
+
|
217 |
+
Calculate the semantic similarity between one text and a list of other sentences.
|
218 |
+
|
219 |
+
```typescript
|
220 |
+
await hf.sentenceSimilarity({
|
221 |
+
model: 'sentence-transformers/paraphrase-xlm-r-multilingual-v1',
|
222 |
+
inputs: {
|
223 |
+
source_sentence: 'That is a happy person',
|
224 |
+
sentences: [
|
225 |
+
'That is a happy dog',
|
226 |
+
'That is a very happy person',
|
227 |
+
'Today is a sunny day'
|
228 |
+
]
|
229 |
+
}
|
230 |
+
})
|
231 |
+
```
|
232 |
+
|
233 |
+
## Audio
|
234 |
+
|
235 |
+
### Automatic Speech Recognition
|
236 |
+
|
237 |
+
Transcribes speech from an audio file.
|
238 |
+
|
239 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/speech-recognition-vue)
|
240 |
+
|
241 |
+
```typescript
|
242 |
+
await hf.automaticSpeechRecognition({
|
243 |
+
model: 'facebook/wav2vec2-large-960h-lv60-self',
|
244 |
+
data: readFileSync('test/sample1.flac')
|
245 |
+
})
|
246 |
+
```
|
247 |
+
|
248 |
+
### Audio Classification
|
249 |
+
|
250 |
+
Assigns labels to the given audio along with a probability score of that label.
|
251 |
+
|
252 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/audio-classification-vue)
|
253 |
+
|
254 |
+
```typescript
|
255 |
+
await hf.audioClassification({
|
256 |
+
model: 'superb/hubert-large-superb-er',
|
257 |
+
data: readFileSync('test/sample1.flac')
|
258 |
+
})
|
259 |
+
```
|
260 |
+
|
261 |
+
### Text To Speech
|
262 |
+
|
263 |
+
Generates natural-sounding speech from text input.
|
264 |
+
|
265 |
+
[Interactive tutorial](https://scrimba.com/scrim/co8da4d23b49b648f77f4848a?pl=pkVnrP7uP)
|
266 |
+
|
267 |
+
```typescript
|
268 |
+
await hf.textToSpeech({
|
269 |
+
model: 'espnet/kan-bayashi_ljspeech_vits',
|
270 |
+
inputs: 'Hello world!'
|
271 |
+
})
|
272 |
+
```
|
273 |
+
|
274 |
+
### Audio To Audio
|
275 |
+
|
276 |
+
Outputs one or multiple generated audios from an input audio, commonly used for speech enhancement and source separation.
|
277 |
+
|
278 |
+
```typescript
|
279 |
+
await hf.audioToAudio({
|
280 |
+
model: 'speechbrain/sepformer-wham',
|
281 |
+
data: readFileSync('test/sample1.flac')
|
282 |
+
})
|
283 |
+
```
|
284 |
+
|
285 |
+
## Computer Vision
|
286 |
+
|
287 |
+
### Image Classification
|
288 |
+
|
289 |
+
Assigns labels to a given image along with a probability score of that label.
|
290 |
+
|
291 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/image-classification-vue)
|
292 |
+
|
293 |
+
```typescript
|
294 |
+
await hf.imageClassification({
|
295 |
+
data: readFileSync('test/cheetah.png'),
|
296 |
+
model: 'google/vit-base-patch16-224'
|
297 |
+
})
|
298 |
+
```
|
299 |
+
|
300 |
+
### Object Detection
|
301 |
+
|
302 |
+
Detects objects within an image and returns labels with corresponding bounding boxes and probability scores.
|
303 |
+
|
304 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/object-detection-vue)
|
305 |
+
|
306 |
+
```typescript
|
307 |
+
await hf.objectDetection({
|
308 |
+
data: readFileSync('test/cats.png'),
|
309 |
+
model: 'facebook/detr-resnet-50'
|
310 |
+
})
|
311 |
+
```
|
312 |
+
|
313 |
+
### Image Segmentation
|
314 |
+
|
315 |
+
Detects segments within an image and returns labels with corresponding masks and probability scores.
|
316 |
+
|
317 |
+
```typescript
|
318 |
+
await hf.imageSegmentation({
|
319 |
+
data: readFileSync('test/cats.png'),
|
320 |
+
model: 'facebook/detr-resnet-50-panoptic'
|
321 |
+
})
|
322 |
+
```
|
323 |
+
|
324 |
+
### Image To Text
|
325 |
+
|
326 |
+
Outputs text from a given image, commonly used for captioning or optical character recognition.
|
327 |
+
|
328 |
+
```typescript
|
329 |
+
await hf.imageToText({
|
330 |
+
data: readFileSync('test/cats.png'),
|
331 |
+
model: 'nlpconnect/vit-gpt2-image-captioning'
|
332 |
+
})
|
333 |
+
```
|
334 |
+
|
335 |
+
### Text To Image
|
336 |
+
|
337 |
+
Creates an image from a text prompt.
|
338 |
+
|
339 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/image-to-text)
|
340 |
+
|
341 |
+
```typescript
|
342 |
+
await hf.textToImage({
|
343 |
+
inputs: 'award winning high resolution photo of a giant tortoise/((ladybird)) hybrid, [trending on artstation]',
|
344 |
+
model: 'stabilityai/stable-diffusion-2',
|
345 |
+
parameters: {
|
346 |
+
negative_prompt: 'blurry',
|
347 |
+
}
|
348 |
+
})
|
349 |
+
```
|
350 |
+
|
351 |
+
### Image To Image
|
352 |
+
|
353 |
+
Image-to-image is the task of transforming a source image to match the characteristics of a target image or a target image domain.
|
354 |
+
|
355 |
+
[Interactive tutorial](https://scrimba.com/scrim/co4834bf9a91cc81cfab07969?pl=pkVnrP7uP)
|
356 |
+
|
357 |
+
```typescript
|
358 |
+
await hf.imageToImage({
|
359 |
+
inputs: new Blob([readFileSync("test/stormtrooper_depth.png")]),
|
360 |
+
parameters: {
|
361 |
+
prompt: "elmo's lecture",
|
362 |
+
},
|
363 |
+
model: "lllyasviel/sd-controlnet-depth",
|
364 |
+
});
|
365 |
+
```
|
366 |
+
|
367 |
+
### Zero Shot Image Classification
|
368 |
+
|
369 |
+
Checks how well an input image fits into a set of labels you provide.
|
370 |
+
|
371 |
+
```typescript
|
372 |
+
await hf.zeroShotImageClassification({
|
373 |
+
model: 'openai/clip-vit-large-patch14-336',
|
374 |
+
inputs: {
|
375 |
+
image: await (await fetch('https://placekitten.com/300/300')).blob()
|
376 |
+
},
|
377 |
+
parameters: {
|
378 |
+
candidate_labels: ['cat', 'dog']
|
379 |
+
}
|
380 |
+
})
|
381 |
+
```
|
382 |
+
|
383 |
+
## Multimodal
|
384 |
+
|
385 |
+
### Feature Extraction
|
386 |
+
|
387 |
+
This task reads some text and outputs raw float values, which are usually consumed as part of a semantic database/semantic search.
|
388 |
+
|
389 |
+
```typescript
|
390 |
+
await hf.featureExtraction({
|
391 |
+
model: "sentence-transformers/distilbert-base-nli-mean-tokens",
|
392 |
+
inputs: "That is a happy person",
|
393 |
+
});
|
394 |
+
```
|
395 |
+
|
396 |
+
### Visual Question Answering
|
397 |
+
|
398 |
+
Visual Question Answering is the task of answering open-ended questions based on an image. Visual question answering models output natural language responses to natural language questions.
|
399 |
+
|
400 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/doc-vis-qa)
|
401 |
+
|
402 |
+
```typescript
|
403 |
+
await hf.visualQuestionAnswering({
|
404 |
+
model: 'dandelin/vilt-b32-finetuned-vqa',
|
405 |
+
inputs: {
|
406 |
+
question: 'How many cats are lying down?',
|
407 |
+
image: await (await fetch('https://placekitten.com/300/300')).blob()
|
408 |
+
}
|
409 |
+
})
|
410 |
+
```
|
411 |
+
|
412 |
+
### Document Question Answering
|
413 |
+
|
414 |
+
Document question answering models take a (document, question) pair as input and return an answer in natural language.
|
415 |
+
|
416 |
+
[Demo](https://huggingface.co/spaces/huggingfacejs/doc-vis-qa)
|
417 |
+
|
418 |
+
```typescript
|
419 |
+
await hf.documentQuestionAnswering({
|
420 |
+
model: 'impira/layoutlm-document-qa',
|
421 |
+
inputs: {
|
422 |
+
question: 'Invoice number?',
|
423 |
+
image: await (await fetch('https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png')).blob(),
|
424 |
+
}
|
425 |
+
})
|
426 |
+
```
|
427 |
+
|
428 |
+
## Tabular
|
429 |
+
|
430 |
+
### Tabular Regression
|
431 |
+
|
432 |
+
Tabular regression is the task of predicting a numerical value given a set of attributes.
|
433 |
+
|
434 |
+
```typescript
|
435 |
+
await hf.tabularRegression({
|
436 |
+
model: "scikit-learn/Fish-Weight",
|
437 |
+
inputs: {
|
438 |
+
data: {
|
439 |
+
"Height": ["11.52", "12.48", "12.3778"],
|
440 |
+
"Length1": ["23.2", "24", "23.9"],
|
441 |
+
"Length2": ["25.4", "26.3", "26.5"],
|
442 |
+
"Length3": ["30", "31.2", "31.1"],
|
443 |
+
"Species": ["Bream", "Bream", "Bream"],
|
444 |
+
"Width": ["4.02", "4.3056", "4.6961"]
|
445 |
+
},
|
446 |
+
},
|
447 |
+
})
|
448 |
+
```
|
449 |
+
|
450 |
+
### Tabular Classification
|
451 |
+
|
452 |
+
Tabular classification is the task of classifying a target category (a group) based on a set of attributes.
|
453 |
+
|
454 |
+
```typescript
|
455 |
+
await hf.tabularClassification({
|
456 |
+
model: "vvmnnnkv/wine-quality",
|
457 |
+
inputs: {
|
458 |
+
data: {
|
459 |
+
"fixed_acidity": ["7.4", "7.8", "10.3"],
|
460 |
+
"volatile_acidity": ["0.7", "0.88", "0.32"],
|
461 |
+
"citric_acid": ["0", "0", "0.45"],
|
462 |
+
"residual_sugar": ["1.9", "2.6", "6.4"],
|
463 |
+
"chlorides": ["0.076", "0.098", "0.073"],
|
464 |
+
"free_sulfur_dioxide": ["11", "25", "5"],
|
465 |
+
"total_sulfur_dioxide": ["34", "67", "13"],
|
466 |
+
"density": ["0.9978", "0.9968", "0.9976"],
|
467 |
+
"pH": ["3.51", "3.2", "3.23"],
|
468 |
+
"sulphates": ["0.56", "0.68", "0.82"],
|
469 |
+
"alcohol": ["9.4", "9.8", "12.6"]
|
470 |
+
},
|
471 |
+
},
|
472 |
+
})
|
473 |
+
```
|
474 |
+
|
475 |
+
## Custom Calls
|
476 |
+
|
477 |
+
For models with custom parameters / outputs.
|
478 |
+
|
479 |
+
```typescript
|
480 |
+
await hf.request({
|
481 |
+
model: 'my-custom-model',
|
482 |
+
inputs: 'hello world',
|
483 |
+
parameters: {
|
484 |
+
custom_param: 'some magic',
|
485 |
+
}
|
486 |
+
})
|
487 |
+
|
488 |
+
// Custom streaming call, for models with custom parameters / outputs
|
489 |
+
for await (const output of hf.streamingRequest({
|
490 |
+
model: 'my-custom-model',
|
491 |
+
inputs: 'hello world',
|
492 |
+
parameters: {
|
493 |
+
custom_param: 'some magic',
|
494 |
+
}
|
495 |
+
})) {
|
496 |
+
...
|
497 |
+
}
|
498 |
+
```
|
499 |
+
|
500 |
+
## Custom Inference Endpoints
|
501 |
+
|
502 |
+
Learn more about using your own inference endpoints [here](https://hf.co/docs/inference-endpoints/)
|
503 |
+
|
504 |
+
```typescript
|
505 |
+
const gpt2 = hf.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
|
506 |
+
const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});
|
507 |
+
```
|
508 |
+
|
509 |
+
## Running tests
|
510 |
+
|
511 |
+
```console
|
512 |
+
HF_TOKEN="your access token" pnpm run test
|
513 |
+
```
|
514 |
+
|
515 |
+
## Finding appropriate models
|
516 |
+
|
517 |
+
We have an informative documentation project called [Tasks](https://huggingface.co/tasks) to list available models for each task and explain how each task works in detail.
|
518 |
+
|
519 |
+
It also contains demos, example outputs, and other resources should you want to dig deeper into the ML side of things.
|
packages/inference/package.json
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "@huggingface/inference",
|
3 |
+
"version": "2.6.4",
|
4 |
+
"packageManager": "pnpm@8.10.5",
|
5 |
+
"license": "MIT",
|
6 |
+
"author": "Tim Mikeladze <tim.mikeladze@gmail.com>",
|
7 |
+
"description": "Typescript wrapper for the Hugging Face Inference Endpoints & Inference API",
|
8 |
+
"repository": {
|
9 |
+
"type": "git",
|
10 |
+
"url": "https://github.com/huggingface/huggingface.js.git"
|
11 |
+
},
|
12 |
+
"publishConfig": {
|
13 |
+
"access": "public"
|
14 |
+
},
|
15 |
+
"keywords": [
|
16 |
+
"hugging face",
|
17 |
+
"hugging face typescript",
|
18 |
+
"huggingface",
|
19 |
+
"huggingface-inference-api",
|
20 |
+
"huggingface-inference-api-typescript",
|
21 |
+
"inference",
|
22 |
+
"ai"
|
23 |
+
],
|
24 |
+
"engines": {
|
25 |
+
"node": ">=18"
|
26 |
+
},
|
27 |
+
"files": [
|
28 |
+
"dist",
|
29 |
+
"src"
|
30 |
+
],
|
31 |
+
"source": "src/index.ts",
|
32 |
+
"types": "./dist/index.d.ts",
|
33 |
+
"main": "./dist/index.cjs",
|
34 |
+
"module": "./dist/index.js",
|
35 |
+
"exports": {
|
36 |
+
"types": "./dist/index.d.ts",
|
37 |
+
"require": "./dist/index.cjs",
|
38 |
+
"import": "./dist/index.js"
|
39 |
+
},
|
40 |
+
"type": "module",
|
41 |
+
"scripts": {
|
42 |
+
"build": "tsup src/index.ts --format cjs,esm --clean && pnpm run dts",
|
43 |
+
"dts": "tsx scripts/generate-dts.ts",
|
44 |
+
"lint": "eslint --quiet --fix --ext .cjs,.ts .",
|
45 |
+
"lint:check": "eslint --ext .cjs,.ts .",
|
46 |
+
"format": "prettier --write .",
|
47 |
+
"format:check": "prettier --check .",
|
48 |
+
"prepare": "pnpm run build",
|
49 |
+
"prepublishOnly": "pnpm run build",
|
50 |
+
"test": "vitest run --config vitest.config.mts",
|
51 |
+
"test:browser": "vitest run --browser.name=chrome --browser.headless --config vitest.config.mts",
|
52 |
+
"check": "tsc"
|
53 |
+
},
|
54 |
+
"devDependencies": {
|
55 |
+
"@huggingface/tasks": "workspace:^",
|
56 |
+
"@types/node": "18.13.0"
|
57 |
+
},
|
58 |
+
"resolutions": {}
|
59 |
+
}
|
packages/inference/pnpm-lock.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lockfileVersion: '6.0'
|
2 |
+
|
3 |
+
settings:
|
4 |
+
autoInstallPeers: true
|
5 |
+
excludeLinksFromLockfile: false
|
6 |
+
|
7 |
+
devDependencies:
|
8 |
+
'@huggingface/tasks':
|
9 |
+
specifier: workspace:^
|
10 |
+
version: link:../tasks
|
11 |
+
'@types/node':
|
12 |
+
specifier: 18.13.0
|
13 |
+
version: 18.13.0
|
14 |
+
|
15 |
+
packages:
|
16 |
+
|
17 |
+
/@types/node@18.13.0:
|
18 |
+
resolution: {integrity: sha512-gC3TazRzGoOnoKAhUx+Q0t8S9Tzs74z7m0ipwGpSqQrleP14hKxP4/JUeEQcD3W1/aIpnWl8pHowI7WokuZpXg==}
|
19 |
+
dev: true
|
packages/inference/scripts/generate-dts.ts
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/** Dirty script to generate pretty .d.ts */
|
2 |
+
|
3 |
+
import { readFileSync, writeFileSync, appendFileSync, readdirSync } from "node:fs";
|
4 |
+
import { TASKS_DATA } from "@huggingface/tasks";
|
5 |
+
|
6 |
+
/** Path of the declaration file assembled by this script. */
const DTS_PATH = "./dist/index.d.ts";

// Every task key from TASKS_DATA except "other", sorted.
const tasks = Object.keys(TASKS_DATA)
	.sort()
	.filter((task) => task !== "other");

let types = readFileSync("./src/types.ts", "utf-8");

// Drop the @huggingface/tasks import and inline the task names as a string-literal union.
types = types.replace(/import.* "@huggingface\/tasks";\n/g, "");
types = types.replace(' Exclude<PipelineType, "other">', ["", ...tasks.map((task) => `"${task}"`)].join("\n\t| "));

// Any leftover mention means one of the replacements above did not match: abort.
if (types.includes("PipelineType") || types.includes("@huggingface/tasks")) {
	console.log(types);
	console.error("Failed to parse types.ts");
	process.exit(1);
}

// The first write creates/truncates the file; everything after it is appended in order.
writeFileSync(DTS_PATH, types + "\n");
appendFileSync(DTS_PATH, "export class InferenceOutputError extends TypeError {}" + "\n");

const dirs = readdirSync("./src/tasks");

// Function declarations collected from all task files, reused below to build the class methods.
const fns: string[] = [];
for (const dir of dirs) {
	// Entries ending in ".ts" are files sitting directly in src/tasks, not task directories.
	if (dir.endsWith(".ts")) {
		continue;
	}
	const files = readdirSync(`./src/tasks/${dir}`);
	for (const file of files) {
		if (!file.endsWith(".ts")) {
			continue;
		}

		const fileContent = readFileSync(`./src/tasks/${dir}/${file}`, "utf-8");

		for (const type of extractTypesAndInterfaces(fileContent)) {
			appendFileSync(DTS_PATH, type + "\n");
		}

		for (const fn of extractAsyncFunctions(fileContent)) {
			appendFileSync(DTS_PATH, fn + "\n");
			fns.push(fn);
		}
	}
}

appendFileSync(
	DTS_PATH,
	`export class HfInference {
\tconstructor(accessToken?: string, defaultOptions?: Options);
\t/**
\t * Returns copy of HfInference tied to a specified endpoint.
\t */
\tendpoint(endpointUrl: string): HfInferenceEndpoint;
` +
		fns.map((fn) => toMethodDeclaration(fn, "'accessToken'")).join("\n") +
		"\n}\n"
);

appendFileSync(
	DTS_PATH,
	`export class HfInferenceEndpoint {\n\tconstructor(endpointUrl: string, accessToken?: string, defaultOptions?: Options);\n` +
		fns.map((fn) => toMethodDeclaration(fn, "'accessToken' | 'model'")).join("\n") +
		"\n}\n"
);

/**
 * Rewrites one collected function declaration into a class-method declaration.
 *
 * The first `args: SomeType` is wrapped as `args: Omit<SomeType, omittedKeys>`, the first
 * "export function " is removed, and every line is indented with one tab.
 *
 * @param fn Declaration text as yielded by `extractAsyncFunctions` (presumably starting with
 *   "export function " - confirm against that generator).
 * @param omittedKeys Union of quoted key names to omit, inserted verbatim into the `Omit<...>`.
 * @returns The indented method declaration.
 */
function toMethodDeclaration(fn: string, omittedKeys: string): string {
	return fn
		.replace(/args: [a-zA-Z]+/, (args) => `args: Omit<${args.slice("args: ".length)}, ${omittedKeys}>`)
		.replace("export function ", "")
		.split("\n")
		.map((line) => "\t" + line)
		.join("\n");
}
|
89 |
+
|
90 |
+
/**
 * Yields the source text of every `export type ...` and `export interface ...` declaration
 * found in `fileContent`: all type aliases first, then all interfaces, each in file order.
 *
 * A declaration ends at the first `;` outside any braces; an interface also ends at the `}`
 * that closes its body. A block comment ending immediately before the declaration is included
 * in the yielded text.
 *
 * Brace counting is purely textual: braces inside string literals or comments are counted too.
 * If a declaration has no detectable end, the script logs the offending snippet and exits
 * with code 1.
 *
 * @param fileContent Full text of a TypeScript source file.
 * @returns An iterable of declaration snippets, with no trailing newline.
 */
function* extractTypesAndInterfaces(fileContent: string): Iterable<string> {
	for (const kind of ["type", "interface"]) {
		// Each kind is scanned from the top of the file.
		let cursor = 0;

		while (true) {
			const start = fileContent.indexOf(`export ${kind} `, cursor);
			if (start === -1) {
				break;
			}

			const end = findDeclarationEnd(fileContent, start, kind === "interface");
			if (end === -1) {
				console.error("Failed to parse fileContent", fileContent.slice(start, start + 100));
				process.exit(1);
			}

			yield fileContent.slice(findDocCommentStart(fileContent, start), end + 1);
			cursor = end + 1;
		}
	}
}

/**
 * Returns the index of the character that terminates the declaration starting at `start`,
 * or -1 when the end of the text is reached first.
 *
 * @param closesOnBrace When true, the `}` that brings the brace depth back to zero also
 *   terminates the declaration (used for interfaces).
 */
function findDeclarationEnd(fileContent: string, start: number, closesOnBrace: boolean): number {
	let depth = 0;

	for (let i = start; i < fileContent.length; i++) {
		const char = fileContent[i];

		if (char === "{") {
			depth++;
		} else if (char === "}") {
			depth--;
			if (depth === 0 && closesOnBrace) {
				return i;
			}
		} else if (char === ";" && depth === 0) {
			return i;
		}
	}

	return -1;
}

/**
 * Returns where the yielded snippet should begin: the start of the block comment that ends
 * right before the declaration, or `start` itself when there is none.
 *
 * The comment must end exactly one character (normally "\n") or one CRLF pair before the
 * declaration.
 */
function findDocCommentStart(fileContent: string, start: number): number {
	const endsOneCharBefore = fileContent[start - 2] === "/" && fileContent[start - 3] === "*";
	// Files checked out with Windows line endings have "\r\n" between the comment and the declaration.
	const endsOneCrLfBefore = start >= 4 && fileContent.slice(start - 4, start) === "*/\r\n";

	if (!endsOneCharBefore && !endsOneCrLfBefore) {
		return start;
	}

	const commentStart = fileContent.lastIndexOf("/*", start);
	// A stray "*/" without an opener: fall back to the declaration itself.
	return commentStart === -1 ? start : commentStart;
}
|
141 |
+
|
142 |
+
/**
 * Yields a declaration-style signature (terminated by `;`) for every
 * `export async function` found in `fileContent`.
 *
 * The signature runs from the function keyword (or from a directly preceding block comment)
 * up to the ` {` that follows the first `): ` after the keyword. The `async` keyword and a
 * generator `*` are dropped from the yielded text.
 *
 * Logs an excerpt and exits the process with code 1 when either marker cannot be found.
 */
function* extractAsyncFunctions(fileContent: string): Iterable<string> {
	let searchFrom = 0;

	for (;;) {
		const fnStart = fileContent.indexOf(`export async function`, searchFrom);
		if (fnStart === -1) {
			return;
		}

		// The return type sits between "): " and the " {" opening the function body.
		const returnTypeStart = fileContent.indexOf("): ", fnStart);
		const bodyStart = returnTypeStart === -1 ? -1 : fileContent.indexOf(" {", returnTypeStart);

		if (bodyStart === -1) {
			console.error("Failed to parse fileContent", fileContent.slice(fnStart, fnStart + 100));
			process.exit(1);
		}

		// Include the doc comment when one ends right before the function.
		const hasDocComment = fileContent[fnStart - 2] === "/" && fileContent[fnStart - 3] === "*";
		const sliceStart = hasDocComment ? fileContent.lastIndexOf("/*", fnStart) : fnStart;

		const signature = fileContent
			.slice(sliceStart, bodyStart)
			.replace("export async ", "export ")
			.replace("export function*", "export function")
			.trim();

		yield signature + ";";
		searchFrom = bodyStart;
	}
}
|
176 |
+
|
177 |
+
// Prefix both built bundles with a triple-slash directive pointing at the generated typings.
["./dist/index.js", "./dist/index.cjs"].forEach((distPath) => {
	const bundle = readFileSync(distPath, "utf-8");
	writeFileSync(distPath, '/// <reference path="./index.d.ts" />\n' + bundle);
});
|
packages/inference/src/HfInference.ts
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import * as tasks from "./tasks";
|
2 |
+
import type { Options, RequestArgs } from "./types";
|
3 |
+
import type { DistributiveOmit } from "./utils/distributive-omit";
|
4 |
+
|
5 |
+
type Task = typeof tasks;
|
6 |
+
|
7 |
+
type TaskWithNoAccessToken = {
|
8 |
+
[key in keyof Task]: (
|
9 |
+
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken">,
|
10 |
+
options?: Parameters<Task[key]>[1]
|
11 |
+
) => ReturnType<Task[key]>;
|
12 |
+
};
|
13 |
+
|
14 |
+
type TaskWithNoAccessTokenNoModel = {
|
15 |
+
[key in keyof Task]: (
|
16 |
+
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
|
17 |
+
options?: Parameters<Task[key]>[1]
|
18 |
+
) => ReturnType<Task[key]>;
|
19 |
+
};
|
20 |
+
|
21 |
+
/**
 * Client that exposes every function exported by `./tasks` as an instance method,
 * pre-filled with the access token and default options given to the constructor.
 */
export class HfInference {
	private readonly accessToken: string;
	private readonly defaultOptions: Options;

	/**
	 * @param accessToken - Token injected into the arguments of every task call
	 * @param defaultOptions - Options merged under the per-call options of every task call
	 */
	constructor(accessToken = "", defaultOptions: Options = {}) {
		this.accessToken = accessToken;
		this.defaultOptions = defaultOptions;

		// Define one non-enumerable method per task; per-call options win over the defaults.
		Object.entries(tasks).forEach(([taskName, taskFn]) => {
			Object.defineProperty(this, taskName, {
				enumerable: false,
				value: (params: RequestArgs, options: Options) =>
					// eslint-disable-next-line @typescript-eslint/no-explicit-any
					taskFn({ ...params, accessToken } as any, { ...defaultOptions, ...options }),
			});
		});
	}

	/**
	 * Returns copy of HfInference tied to a specified endpoint.
	 */
	public endpoint(endpointUrl: string): HfInferenceEndpoint {
		const { accessToken, defaultOptions } = this;
		return new HfInferenceEndpoint(endpointUrl, accessToken, defaultOptions);
	}
}
|
46 |
+
|
47 |
+
/**
 * Client that exposes every function exported by `./tasks` as an instance method,
 * with the endpoint URL passed as the `model` argument of each call.
 */
export class HfInferenceEndpoint {
	/**
	 * @param endpointUrl - URL passed as `model` to every task call
	 * @param accessToken - Token injected into the arguments of every task call
	 * @param defaultOptions - Options merged under the per-call options of every task call
	 */
	constructor(endpointUrl: string, accessToken = "", defaultOptions: Options = {}) {
		// Define one non-enumerable method per task; per-call options win over the defaults.
		Object.entries(tasks).forEach(([taskName, taskFn]) => {
			Object.defineProperty(this, taskName, {
				enumerable: false,
				value: (params: RequestArgs, options: Options) =>
					// eslint-disable-next-line @typescript-eslint/no-explicit-any
					taskFn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
			});
		});
	}
}
|
62 |
+
|
63 |
+
// eslint-disable-next-line @typescript-eslint/no-empty-interface
|
64 |
+
export interface HfInference extends TaskWithNoAccessToken {}
|
65 |
+
|
66 |
+
// eslint-disable-next-line @typescript-eslint/no-empty-interface
|
67 |
+
export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
|
packages/inference/src/index.ts
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export { HfInference, HfInferenceEndpoint } from "./HfInference";
|
2 |
+
export { InferenceOutputError } from "./lib/InferenceOutputError";
|
3 |
+
export * from "./types";
|
4 |
+
export * from "./tasks";
|
packages/inference/src/lib/InferenceOutputError.ts
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Error thrown when an inference response does not have the shape a task expects.
 * The message points callers to the untyped `request` method as a fallback.
 */
export class InferenceOutputError extends TypeError {
	/**
	 * @param message - Description of the expected output shape
	 */
	constructor(message: string) {
		super(
			"Invalid inference output: " +
				message +
				". Use the 'request' method with the same parameters to do a custom call with no type checking."
		);
		this.name = "InferenceOutputError";
	}
}
|
packages/inference/src/lib/getDefaultTask.ts
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { isUrl } from "./isUrl";
|
2 |
+
|
3 |
+
/**
|
4 |
+
 * We want to make as few calls to the Hugging Face hub as possible, e.g. if
|
5 |
+
* someone is calling Inference Endpoints 1000 times per second, we don't want
|
6 |
+
* to make 1000 calls to the hub to get the task name.
|
7 |
+
*/
|
8 |
+
const taskCache = new Map<string, { task: string; date: Date }>();
|
9 |
+
const CACHE_DURATION = 10 * 60 * 1000;
|
10 |
+
const MAX_CACHE_ITEMS = 1000;
|
11 |
+
export const HF_HUB_URL = "https://huggingface.co";
|
12 |
+
|
13 |
+
export interface DefaultTaskOptions {
|
14 |
+
fetch?: typeof fetch;
|
15 |
+
}
|
16 |
+
|
17 |
+
/**
 * Get the default task of a model, i.e. the `pipeline_tag` reported by the HF hub.
 * Uses an LRU cache of 1000 items with 10 minutes expiration to avoid making too many
 * calls to the HF hub.
 *
 * @param model - Model id to look up. Values recognized by `isUrl` are never looked up.
 * @param accessToken - Optional token, sent as a Bearer `Authorization` header. It is part of the cache key.
 * @param options - Optional settings; `options.fetch` is used instead of the global `fetch` when provided.
 * @returns The default task for the model, or `null` if it was impossible to get it
 */
export async function getDefaultTask(
	model: string,
	accessToken: string | undefined,
	options?: DefaultTaskOptions
): Promise<string | null> {
	if (isUrl(model)) {
		return null;
	}

	const key = `${model}:${accessToken}`;
	const cachedTask = taskCache.get(key);

	if (cachedTask !== undefined) {
		// Remove the entry in both cases: an expired entry is dropped, a fresh one is
		// re-inserted below so that it becomes the most recently used key.
		taskCache.delete(key);

		const isExpired = cachedTask.date.getTime() < Date.now() - CACHE_DURATION;
		if (!isExpired) {
			// Map iterates in insertion order, so re-inserting on each hit is what makes the
			// "delete the first key" eviction below evict the least recently used entry.
			// The original date is kept: a hit does not extend the expiration.
			taskCache.set(key, cachedTask);
			return cachedTask.task;
		}
	}

	// Any network or parsing failure is treated as "task unknown" rather than an error.
	const modelTask = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/models/${model}?expand[]=pipeline_tag`, {
		headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
	})
		.then((resp) => resp.json())
		.then((json) => json.pipeline_tag)
		.catch(() => null);

	if (!modelTask) {
		return null;
	}

	taskCache.set(key, { task: modelTask, date: new Date() });

	if (taskCache.size > MAX_CACHE_ITEMS) {
		// The first key in iteration order is the least recently used one.
		taskCache.delete(taskCache.keys().next().value);
	}

	return modelTask;
}
|
packages/inference/src/lib/isUrl.ts
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Tells whether the given string should be used as a URL rather than a model id:
 * it either starts with `http:` / `https:` or with a `/`.
 */
export function isUrl(modelOrUrl: string): boolean {
	if (modelOrUrl.startsWith("/")) {
		return true;
	}
	const httpScheme = /^http(s?):/;
	return httpScheme.test(modelOrUrl);
}
|
packages/inference/src/lib/makeRequestOptions.ts
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { InferenceTask, Options, RequestArgs } from "../types";
|
2 |
+
import { HF_HUB_URL } from "./getDefaultTask";
|
3 |
+
import { isUrl } from "./isUrl";
|
4 |
+
|
5 |
+
const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co";
|
6 |
+
|
7 |
+
/**
|
8 |
+
* Loaded from huggingface.co/api/tasks if needed
|
9 |
+
*/
|
10 |
+
let tasks: Record<string, { models: { id: string }[] }> | null = null;
|
11 |
+
|
12 |
+
/**
 * Helper that prepares request arguments
 *
 * Resolves the model (falling back to the first model listed for `taskHint` by the hub's
 * `/api/tasks` endpoint), then builds the target URL and the `fetch` init object.
 *
 * @param args - Request arguments. When `args.data` is set, it is sent as the raw body and
 *   the supported options are passed as HTTP headers; otherwise the remaining arguments are
 *   sent as a JSON body together with the options.
 * @param options - Request options, plus `forceTask` and `taskHint`.
 * @returns The URL to call and the `RequestInit` to pass to `fetch`.
 * @throws Error when no model is provided and no default model could be found for the task.
 */
export async function makeRequestOptions(
	args: RequestArgs & {
		data?: Blob | ArrayBuffer;
		stream?: boolean;
	},
	options?: Options & {
		/** When a model can be used for multiple tasks, and we want to run a non-default task */
		forceTask?: string | InferenceTask;
		/** To load default model if needed */
		taskHint?: InferenceTask;
	}
): Promise<{ url: string; info: RequestInit }> {
	// eslint-disable-next-line @typescript-eslint/no-unused-vars
	const { accessToken, model: _model, ...otherArgs } = args;
	let { model } = args;
	const { forceTask: task, includeCredentials, taskHint, ...otherOptions } = options ?? {};

	const headers: Record<string, string> = {};
	if (accessToken) {
		headers["Authorization"] = `Bearer ${accessToken}`;
	}

	if (!model && !tasks && taskHint) {
		// Honor a user-supplied fetch here too, so the lookup works wherever the actual
		// inference call works (e.g. runtimes without a global `fetch`).
		const res = await (options?.fetch ?? fetch)(`${HF_HUB_URL}/api/tasks`);

		if (res.ok) {
			tasks = await res.json();
		}
	}

	if (!model && tasks && taskHint) {
		const taskInfo = tasks[taskHint];
		// A task may be listed without any model: fall through to the explicit error below
		// instead of failing on `models[0]` being undefined.
		if (taskInfo && taskInfo.models.length > 0) {
			model = taskInfo.models[0].id;
		}
	}

	if (!model) {
		throw new Error("No model provided, and no default model found for this task");
	}

	const binary = "data" in args && !!args.data;

	if (!binary) {
		headers["Content-Type"] = "application/json";
	} else {
		// With a binary body there is no JSON payload to carry the options, so they are sent as headers.
		if (options?.wait_for_model) {
			headers["X-Wait-For-Model"] = "true";
		}
		if (options?.use_cache === false) {
			headers["X-Use-Cache"] = "false";
		}
		if (options?.dont_load_model) {
			headers["X-Load-Model"] = "0";
		}
	}

	const url = (() => {
		if (isUrl(model)) {
			return model;
		}

		if (task) {
			return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
		}

		return `${HF_INFERENCE_API_BASE_URL}/models/${model}`;
	})();

	// Let users configure credentials, or disable them all together (or keep default behavior).
	// ---
	// This used to be an internal property only and never exposed to users. This means that most usages will never define this value
	// So in order to make this backwards compatible, if it's undefined we go to "same-origin" (default behaviour before).
	// If it's a boolean and set to true then set to "include". If false, don't define credentials at all (useful for edge runtimes)
	// Then finally, if it's a string, use it as-is.
	let credentials: RequestCredentials | undefined;
	if (typeof includeCredentials === "string") {
		credentials = includeCredentials as RequestCredentials;
	} else if (typeof includeCredentials === "boolean") {
		credentials = includeCredentials ? "include" : undefined;
	} else if (includeCredentials === undefined) {
		credentials = "same-origin";
	}

	const info: RequestInit = {
		headers,
		method: "POST",
		body: binary
			? args.data
			: JSON.stringify({
					...otherArgs,
					options: options && otherOptions,
			  }),
		credentials,
		signal: options?.signal,
	};

	return { url, info };
}
|
packages/inference/src/tasks/audio/audioClassification.ts
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type AudioClassificationArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary audio data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface AudioClassificationOutputValue {
|
13 |
+
/**
|
14 |
+
* The label for the class (model specific)
|
15 |
+
*/
|
16 |
+
label: string;
|
17 |
+
|
18 |
+
/**
|
19 |
+
* A float that represents how likely it is that the audio file belongs to this class.
|
20 |
+
*/
|
21 |
+
score: number;
|
22 |
+
}
|
23 |
+
|
24 |
+
export type AudioClassificationReturn = AudioClassificationOutputValue[];
|
25 |
+
|
26 |
+
/**
 * Classifies an audio input: returns a list of labels, each with a likelihood score.
 * Recommended model: superb/hubert-large-superb-er
 *
 * @throws InferenceOutputError when the response is not an array of `{label, score}` items.
 */
export async function audioClassification(
	args: AudioClassificationArgs,
	options?: Options
): Promise<AudioClassificationReturn> {
	const output = await request<AudioClassificationReturn>(args, { ...options, taskHint: "audio-classification" });

	// Reject anything that is not a list of {label: string, score: number} entries.
	if (
		!Array.isArray(output) ||
		!output.every((item) => typeof item.label === "string" && typeof item.score === "number")
	) {
		throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
	}

	return output;
}
|
packages/inference/src/tasks/audio/audioToAudio.ts
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type AudioToAudioArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary audio data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface AudioToAudioOutputValue {
|
13 |
+
/**
|
14 |
+
* The label for the audio output (model specific)
|
15 |
+
*/
|
16 |
+
label: string;
|
17 |
+
|
18 |
+
/**
|
19 |
+
* Base64 encoded audio output.
|
20 |
+
*/
|
21 |
+
blob: string;
|
22 |
+
|
23 |
+
/**
|
24 |
+
* Content-type for blob, e.g. audio/flac
|
25 |
+
*/
|
26 |
+
"content-type": string;
|
27 |
+
}
|
28 |
+
|
29 |
+
export type AudioToAudioReturn = AudioToAudioOutputValue[];
|
30 |
+
|
31 |
+
/**
 * Transforms an audio input into one or several audio outputs.
 * Example model: speechbrain/sepformer-wham does audio source separation.
 *
 * @throws InferenceOutputError when the response is not an array of `{label, blob, content-type}` items.
 */
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
	const output = await request<AudioToAudioReturn>(args, { ...options, taskHint: "audio-to-audio" });

	// Every entry must carry three string fields.
	if (
		!Array.isArray(output) ||
		!output.every(
			(item) =>
				typeof item.label === "string" && typeof item.blob === "string" && typeof item["content-type"] === "string"
		)
	) {
		throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
	}

	return output;
}
|
packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type AutomaticSpeechRecognitionArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary audio data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface AutomaticSpeechRecognitionOutput {
|
13 |
+
/**
|
14 |
+
* The text that was recognized from the audio
|
15 |
+
*/
|
16 |
+
text: string;
|
17 |
+
}
|
18 |
+
|
19 |
+
/**
 * Transcribes an audio input into the text spoken in it.
 * Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
 *
 * @throws InferenceOutputError when the response has no string `text` field.
 */
export async function automaticSpeechRecognition(
	args: AutomaticSpeechRecognitionArgs,
	options?: Options
): Promise<AutomaticSpeechRecognitionOutput> {
	const output = await request<AutomaticSpeechRecognitionOutput>(args, {
		...options,
		taskHint: "automatic-speech-recognition",
	});

	if (typeof output?.text !== "string") {
		throw new InferenceOutputError("Expected {text: string}");
	}

	return output;
}
|
packages/inference/src/tasks/audio/textToSpeech.ts
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type TextToSpeechArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* The text to generate an audio from
|
8 |
+
*/
|
9 |
+
inputs: string;
|
10 |
+
};
|
11 |
+
|
12 |
+
export type TextToSpeechOutput = Blob;
|
13 |
+
|
14 |
+
/**
 * Synthesizes audio of a voice pronouncing the given text.
 * Recommended model: espnet/kan-bayashi_ljspeech_vits
 *
 * @throws InferenceOutputError when the response is not a Blob.
 */
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> {
	const output = await request<TextToSpeechOutput>(args, { ...options, taskHint: "text-to-speech" });

	if (!(output instanceof Blob)) {
		throw new InferenceOutputError("Expected Blob");
	}

	return output;
}
|
packages/inference/src/tasks/custom/request.ts
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { InferenceTask, Options, RequestArgs } from "../../types";
|
2 |
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
3 |
+
|
4 |
+
/**
 * Primitive to make custom calls to Inference Endpoints
 *
 * Sends the request built by `makeRequestOptions`, using `options.fetch` when provided.
 * A 503 response is retried once with `wait_for_model: true`, unless `retry_on_error` is
 * `false` or `wait_for_model` was already set.
 *
 * @returns The parsed JSON body when the response is `application/json`, the body as a Blob otherwise.
 * @throws Error with the server-provided `error` field of a JSON error response when there
 *   is one, a generic error for any other non-OK response.
 */
export async function request<T>(
	args: RequestArgs,
	options?: Options & {
		/** When a model can be used for multiple tasks, and we want to run a non-default task */
		task?: string | InferenceTask;
		/** To load default model if needed */
		taskHint?: InferenceTask;
	}
): Promise<T> {
	const { url, info } = await makeRequestOptions(args, options);
	const fetchImpl = options?.fetch ?? fetch;
	const response = await fetchImpl(url, info);

	const shouldRetry = options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model;
	if (shouldRetry) {
		return request(args, {
			...options,
			wait_for_model: true,
		});
	}

	const isJson = response.headers.get("Content-Type")?.startsWith("application/json");

	if (response.ok) {
		if (isJson) {
			return await response.json();
		}
		return (await response.blob()) as T;
	}

	// Failure: surface the server's error message when there is one.
	if (isJson) {
		const output = await response.json();
		if (output.error) {
			throw new Error(output.error);
		}
	}
	throw new Error("An error occurred while fetching the blob");
}
|
packages/inference/src/tasks/custom/streamingRequest.ts
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { InferenceTask, Options, RequestArgs } from "../../types";
|
2 |
+
import { makeRequestOptions } from "../../lib/makeRequestOptions";
|
3 |
+
import type { EventSourceMessage } from "../../vendor/fetch-event-source/parse";
|
4 |
+
import { getLines, getMessages } from "../../vendor/fetch-event-source/parse";
|
5 |
+
|
6 |
+
/**
 * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
 *
 * Sends the request built by `makeRequestOptions` (with `stream: true` added to the
 * arguments), using `options.fetch` when provided. A 503 response is retried once with
 * `wait_for_model: true`, unless `retry_on_error` is `false` or `wait_for_model` was already set.
 *
 * @returns An async generator yielding the JSON-parsed `data` of each non-empty event.
 * @throws Error when the response is not OK, is not a `text/event-stream`, or when an
 *   event payload is an object with an `error` field.
 */
export async function* streamingRequest<T>(
	args: RequestArgs,
	options?: Options & {
		/** When a model can be used for multiple tasks, and we want to run a non-default task */
		task?: string | InferenceTask;
		/** To load default model if needed */
		taskHint?: InferenceTask;
	}
): AsyncGenerator<T> {
	const { url, info } = await makeRequestOptions({ ...args, stream: true }, options);
	const response = await (options?.fetch ?? fetch)(url, info);

	if (options?.retry_on_error !== false && response.status === 503 && !options?.wait_for_model) {
		// Delegate with `yield*`: a plain `return` of the retried generator would end this
		// generator immediately and the caller would never receive any of the retried events.
		return yield* streamingRequest(args, {
			...options,
			wait_for_model: true,
		});
	}
	if (!response.ok) {
		if (response.headers.get("Content-Type")?.startsWith("application/json")) {
			const output = await response.json();
			if (output.error) {
				throw new Error(output.error);
			}
		}

		throw new Error(`Server response contains error: ${response.status}`);
	}
	if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
		throw new Error(
			`Server does not support event stream content type, it returned ` + response.headers.get("content-type")
		);
	}

	if (!response.body) {
		return;
	}

	const reader = response.body.getReader();
	let events: EventSourceMessage[] = [];

	const onEvent = (event: EventSourceMessage) => {
		// accumulate events in array
		events.push(event);
	};

	const onChunk = getLines(
		getMessages(
			() => {},
			() => {},
			onEvent
		)
	);

	try {
		while (true) {
			const { done, value } = await reader.read();
			if (done) return;
			// Parsing a chunk pushes the complete events it contains into `events`.
			onChunk(value);
			for (const event of events) {
				if (event.data.length > 0) {
					const data = JSON.parse(event.data);
					if (typeof data === "object" && data !== null && "error" in data) {
						throw new Error(data.error);
					}
					yield data as T;
				}
			}
			// Everything parsed so far has been handed to the consumer.
			events = [];
		}
	} finally {
		// Runs on completion, on error, and when the consumer stops iterating early.
		reader.releaseLock();
	}
}
|
packages/inference/src/tasks/cv/imageClassification.ts
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type ImageClassificationArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary image data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface ImageClassificationOutputValue {
|
13 |
+
/**
|
14 |
+
* The label for the class (model specific)
|
15 |
+
*/
|
16 |
+
label: string;
|
17 |
+
/**
|
18 |
+
* A float that represents how likely it is that the image file belongs to this class.
|
19 |
+
*/
|
20 |
+
score: number;
|
21 |
+
}
|
22 |
+
|
23 |
+
export type ImageClassificationOutput = ImageClassificationOutputValue[];
|
24 |
+
|
25 |
+
/**
 * Classifies an image input: returns a list of labels, each with a likelihood score.
 * Recommended model: google/vit-base-patch16-224
 *
 * @throws InferenceOutputError when the response is not an array of `{label, score}` items.
 */
export async function imageClassification(
	args: ImageClassificationArgs,
	options?: Options
): Promise<ImageClassificationOutput> {
	const output = await request<ImageClassificationOutput>(args, { ...options, taskHint: "image-classification" });

	// Reject anything that is not a list of {label: string, score: number} entries.
	if (
		!Array.isArray(output) ||
		!output.every((item) => typeof item.label === "string" && typeof item.score === "number")
	) {
		throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
	}

	return output;
}
|
packages/inference/src/tasks/cv/imageSegmentation.ts
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type ImageSegmentationArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary image data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface ImageSegmentationOutputValue {
|
13 |
+
/**
|
14 |
+
* The label for the class (model specific) of a segment.
|
15 |
+
*/
|
16 |
+
label: string;
|
17 |
+
/**
|
18 |
+
* A str (base64 str of a single channel black-and-white img) representing the mask of a segment.
|
19 |
+
*/
|
20 |
+
mask: string;
|
21 |
+
/**
|
22 |
+
* A float that represents how likely it is that the detected object belongs to the given class.
|
23 |
+
*/
|
24 |
+
score: number;
|
25 |
+
}
|
26 |
+
|
27 |
+
export type ImageSegmentationOutput = ImageSegmentationOutputValue[];
|
28 |
+
|
29 |
+
/**
 * Segments an image input: returns a list of segments, each with a label, a mask and a score.
 * Recommended model: facebook/detr-resnet-50-panoptic
 *
 * @throws InferenceOutputError when the response is not an array of `{label, mask, score}` items.
 */
export async function imageSegmentation(
	args: ImageSegmentationArgs,
	options?: Options
): Promise<ImageSegmentationOutput> {
	const output = await request<ImageSegmentationOutput>(args, { ...options, taskHint: "image-segmentation" });

	// Every segment must have a string label, a string mask and a numeric score.
	if (
		!Array.isArray(output) ||
		!output.every(
			(item) => typeof item.label === "string" && typeof item.mask === "string" && typeof item.score === "number"
		)
	) {
		throw new InferenceOutputError("Expected Array<{label: string, mask: string, score: number}>");
	}

	return output;
}
|
packages/inference/src/tasks/cv/imageToImage.ts
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
import { base64FromBytes } from "../../../../shared";
|
5 |
+
|
6 |
+
export type ImageToImageArgs = BaseArgs & {
|
7 |
+
/**
|
8 |
+
* The initial image condition
|
9 |
+
*
|
10 |
+
**/
|
11 |
+
inputs: Blob | ArrayBuffer;
|
12 |
+
|
13 |
+
parameters?: {
|
14 |
+
/**
|
15 |
+
* The text prompt to guide the image generation.
|
16 |
+
*/
|
17 |
+
prompt?: string;
|
18 |
+
/**
|
19 |
+
 * strength param only works for SD img2img and alt diffusion img2img models
|
20 |
+
* Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
21 |
+
* will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
22 |
+
* denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
23 |
+
* be maximum and the denoising process will run for the full number of iterations specified in
|
24 |
+
* `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
25 |
+
**/
|
26 |
+
strength?: number;
|
27 |
+
/**
|
28 |
+
* An optional negative prompt for the image generation
|
29 |
+
*/
|
30 |
+
negative_prompt?: string;
|
31 |
+
/**
|
32 |
+
* The height in pixels of the generated image
|
33 |
+
*/
|
34 |
+
height?: number;
|
35 |
+
/**
|
36 |
+
* The width in pixels of the generated image
|
37 |
+
*/
|
38 |
+
width?: number;
|
39 |
+
/**
|
40 |
+
* The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
41 |
+
*/
|
42 |
+
num_inference_steps?: number;
|
43 |
+
/**
|
44 |
+
* Guidance scale: Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
45 |
+
*/
|
46 |
+
guidance_scale?: number;
|
47 |
+
/**
|
48 |
+
 * guess_mode only works for ControlNet models, defaults to False. In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
49 |
+
* you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
50 |
+
*/
|
51 |
+
guess_mode?: boolean;
|
52 |
+
};
|
53 |
+
};
|
54 |
+
|
55 |
+
export type ImageToImageOutput = Blob;
|
56 |
+
|
57 |
+
/**
|
58 |
+
* This task reads an image input (optionally guided by a text prompt) and outputs an image.
|
59 |
+
* Recommended model: lllyasviel/sd-controlnet-depth
|
60 |
+
*/
|
61 |
+
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<ImageToImageOutput> {
|
62 |
+
let reqArgs: RequestArgs;
|
63 |
+
if (!args.parameters) {
|
64 |
+
reqArgs = {
|
65 |
+
accessToken: args.accessToken,
|
66 |
+
model: args.model,
|
67 |
+
data: args.inputs,
|
68 |
+
};
|
69 |
+
} else {
|
70 |
+
reqArgs = {
|
71 |
+
...args,
|
72 |
+
inputs: base64FromBytes(
|
73 |
+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
|
74 |
+
),
|
75 |
+
};
|
76 |
+
}
|
77 |
+
const res = await request<ImageToImageOutput>(reqArgs, {
|
78 |
+
...options,
|
79 |
+
taskHint: "image-to-image",
|
80 |
+
});
|
81 |
+
const isValidOutput = res && res instanceof Blob;
|
82 |
+
if (!isValidOutput) {
|
83 |
+
throw new InferenceOutputError("Expected Blob");
|
84 |
+
}
|
85 |
+
return res;
|
86 |
+
}
|
packages/inference/src/tasks/cv/imageToText.ts
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type ImageToTextArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary image data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface ImageToTextOutput {
|
13 |
+
/**
|
14 |
+
* The generated caption
|
15 |
+
*/
|
16 |
+
generated_text: string;
|
17 |
+
}
|
18 |
+
|
19 |
+
/**
|
20 |
+
* This task reads some image input and outputs the text caption.
|
21 |
+
*/
|
22 |
+
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
|
23 |
+
const res = (
|
24 |
+
await request<[ImageToTextOutput]>(args, {
|
25 |
+
...options,
|
26 |
+
taskHint: "image-to-text",
|
27 |
+
})
|
28 |
+
)?.[0];
|
29 |
+
|
30 |
+
if (typeof res?.generated_text !== "string") {
|
31 |
+
throw new InferenceOutputError("Expected {generated_text: string}");
|
32 |
+
}
|
33 |
+
|
34 |
+
return res;
|
35 |
+
}
|
packages/inference/src/tasks/cv/objectDetection.ts
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { request } from "../custom/request";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
4 |
+
|
5 |
+
export type ObjectDetectionArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* Binary image data
|
8 |
+
*/
|
9 |
+
data: Blob | ArrayBuffer;
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface ObjectDetectionOutputValue {
|
13 |
+
/**
|
14 |
+
* A dict (with keys [xmin,ymin,xmax,ymax]) representing the bounding box of a detected object.
|
15 |
+
*/
|
16 |
+
box: {
|
17 |
+
xmax: number;
|
18 |
+
xmin: number;
|
19 |
+
ymax: number;
|
20 |
+
ymin: number;
|
21 |
+
};
|
22 |
+
/**
|
23 |
+
* The label for the class (model specific) of a detected object.
|
24 |
+
*/
|
25 |
+
label: string;
|
26 |
+
|
27 |
+
/**
|
28 |
+
* A float that represents how likely it is that the detected object belongs to the given class.
|
29 |
+
*/
|
30 |
+
score: number;
|
31 |
+
}
|
32 |
+
|
33 |
+
export type ObjectDetectionOutput = ObjectDetectionOutputValue[];
|
34 |
+
|
35 |
+
/**
|
36 |
+
* This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
|
37 |
+
* Recommended model: facebook/detr-resnet-50
|
38 |
+
*/
|
39 |
+
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
|
40 |
+
const res = await request<ObjectDetectionOutput>(args, {
|
41 |
+
...options,
|
42 |
+
taskHint: "object-detection",
|
43 |
+
});
|
44 |
+
const isValidOutput =
|
45 |
+
Array.isArray(res) &&
|
46 |
+
res.every(
|
47 |
+
(x) =>
|
48 |
+
typeof x.label === "string" &&
|
49 |
+
typeof x.score === "number" &&
|
50 |
+
typeof x.box.xmin === "number" &&
|
51 |
+
typeof x.box.ymin === "number" &&
|
52 |
+
typeof x.box.xmax === "number" &&
|
53 |
+
typeof x.box.ymax === "number"
|
54 |
+
);
|
55 |
+
if (!isValidOutput) {
|
56 |
+
throw new InferenceOutputError(
|
57 |
+
"Expected Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
|
58 |
+
);
|
59 |
+
}
|
60 |
+
return res;
|
61 |
+
}
|
packages/inference/src/tasks/cv/textToImage.ts
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type TextToImageArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* The text to generate an image from
|
8 |
+
*/
|
9 |
+
inputs: string;
|
10 |
+
|
11 |
+
parameters?: {
|
12 |
+
/**
|
13 |
+
* An optional negative prompt for the image generation
|
14 |
+
*/
|
15 |
+
negative_prompt?: string;
|
16 |
+
/**
|
17 |
+
* The height in pixels of the generated image
|
18 |
+
*/
|
19 |
+
height?: number;
|
20 |
+
/**
|
21 |
+
* The width in pixels of the generated image
|
22 |
+
*/
|
23 |
+
width?: number;
|
24 |
+
/**
|
25 |
+
* The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
|
26 |
+
*/
|
27 |
+
num_inference_steps?: number;
|
28 |
+
/**
|
29 |
+
* Guidance scale: Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
30 |
+
*/
|
31 |
+
guidance_scale?: number;
|
32 |
+
};
|
33 |
+
};
|
34 |
+
|
35 |
+
export type TextToImageOutput = Blob;
|
36 |
+
|
37 |
+
/**
|
38 |
+
* This task reads some text input and outputs an image.
|
39 |
+
* Recommended model: stabilityai/stable-diffusion-2
|
40 |
+
*/
|
41 |
+
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
|
42 |
+
const res = await request<TextToImageOutput>(args, {
|
43 |
+
...options,
|
44 |
+
taskHint: "text-to-image",
|
45 |
+
});
|
46 |
+
const isValidOutput = res && res instanceof Blob;
|
47 |
+
if (!isValidOutput) {
|
48 |
+
throw new InferenceOutputError("Expected Blob");
|
49 |
+
}
|
50 |
+
return res;
|
51 |
+
}
|
packages/inference/src/tasks/cv/zeroShotImageClassification.ts
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
import type { RequestArgs } from "../../types";
|
5 |
+
import { base64FromBytes } from "../../../../shared";
|
6 |
+
|
7 |
+
export type ZeroShotImageClassificationArgs = BaseArgs & {
|
8 |
+
inputs: {
|
9 |
+
/**
|
10 |
+
* Binary image data
|
11 |
+
*/
|
12 |
+
image: Blob | ArrayBuffer;
|
13 |
+
};
|
14 |
+
parameters: {
|
15 |
+
/**
|
16 |
+
* A list of strings that are potential classes for inputs. (max 10)
|
17 |
+
*/
|
18 |
+
candidate_labels: string[];
|
19 |
+
};
|
20 |
+
};
|
21 |
+
|
22 |
+
export interface ZeroShotImageClassificationOutputValue {
|
23 |
+
label: string;
|
24 |
+
score: number;
|
25 |
+
}
|
26 |
+
|
27 |
+
export type ZeroShotImageClassificationOutput = ZeroShotImageClassificationOutputValue[];
|
28 |
+
|
29 |
+
/**
|
30 |
+
* Classify an image to specified classes.
|
31 |
+
* Recommended model: openai/clip-vit-large-patch14-336
|
32 |
+
*/
|
33 |
+
export async function zeroShotImageClassification(
|
34 |
+
args: ZeroShotImageClassificationArgs,
|
35 |
+
options?: Options
|
36 |
+
): Promise<ZeroShotImageClassificationOutput> {
|
37 |
+
const reqArgs: RequestArgs = {
|
38 |
+
...args,
|
39 |
+
inputs: {
|
40 |
+
image: base64FromBytes(
|
41 |
+
new Uint8Array(
|
42 |
+
args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
|
43 |
+
)
|
44 |
+
),
|
45 |
+
},
|
46 |
+
} as RequestArgs;
|
47 |
+
|
48 |
+
const res = await request<ZeroShotImageClassificationOutput>(reqArgs, {
|
49 |
+
...options,
|
50 |
+
taskHint: "zero-shot-image-classification",
|
51 |
+
});
|
52 |
+
const isValidOutput =
|
53 |
+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
|
54 |
+
if (!isValidOutput) {
|
55 |
+
throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
|
56 |
+
}
|
57 |
+
return res;
|
58 |
+
}
|
packages/inference/src/tasks/index.ts
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Custom tasks with arbitrary inputs and outputs
|
2 |
+
export * from "./custom/request";
|
3 |
+
export * from "./custom/streamingRequest";
|
4 |
+
|
5 |
+
// Audio tasks
|
6 |
+
export * from "./audio/audioClassification";
|
7 |
+
export * from "./audio/automaticSpeechRecognition";
|
8 |
+
export * from "./audio/textToSpeech";
|
9 |
+
export * from "./audio/audioToAudio";
|
10 |
+
|
11 |
+
// Computer Vision tasks
|
12 |
+
export * from "./cv/imageClassification";
|
13 |
+
export * from "./cv/imageSegmentation";
|
14 |
+
export * from "./cv/imageToText";
|
15 |
+
export * from "./cv/objectDetection";
|
16 |
+
export * from "./cv/textToImage";
|
17 |
+
export * from "./cv/imageToImage";
|
18 |
+
export * from "./cv/zeroShotImageClassification";
|
19 |
+
|
20 |
+
// Natural Language Processing tasks
|
21 |
+
export * from "./nlp/featureExtraction";
|
22 |
+
export * from "./nlp/fillMask";
|
23 |
+
export * from "./nlp/questionAnswering";
|
24 |
+
export * from "./nlp/sentenceSimilarity";
|
25 |
+
export * from "./nlp/summarization";
|
26 |
+
export * from "./nlp/tableQuestionAnswering";
|
27 |
+
export * from "./nlp/textClassification";
|
28 |
+
export * from "./nlp/textGeneration";
|
29 |
+
export * from "./nlp/textGenerationStream";
|
30 |
+
export * from "./nlp/tokenClassification";
|
31 |
+
export * from "./nlp/translation";
|
32 |
+
export * from "./nlp/zeroShotClassification";
|
33 |
+
|
34 |
+
// Multimodal tasks
|
35 |
+
export * from "./multimodal/documentQuestionAnswering";
|
36 |
+
export * from "./multimodal/visualQuestionAnswering";
|
37 |
+
|
38 |
+
// Tabular tasks
|
39 |
+
export * from "./tabular/tabularRegression";
|
40 |
+
export * from "./tabular/tabularClassification";
|
packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
import type { RequestArgs } from "../../types";
|
5 |
+
import { base64FromBytes } from "../../../../shared";
|
6 |
+
import { toArray } from "../../utils/toArray";
|
7 |
+
|
8 |
+
export type DocumentQuestionAnsweringArgs = BaseArgs & {
|
9 |
+
inputs: {
|
10 |
+
/**
|
11 |
+
* Raw image
|
12 |
+
*
|
13 |
+
* You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([Uint8Array.from(atob(base64String), (c) => c.charCodeAt(0))])`, or even `await (await fetch('...')).blob()`
|
14 |
+
**/
|
15 |
+
image: Blob | ArrayBuffer;
|
16 |
+
question: string;
|
17 |
+
};
|
18 |
+
};
|
19 |
+
|
20 |
+
export interface DocumentQuestionAnsweringOutput {
|
21 |
+
/**
|
22 |
+
* A string that’s the answer within the document.
|
23 |
+
*/
|
24 |
+
answer: string;
|
25 |
+
/**
|
26 |
+
* The end word index of the answer (in the OCR'd version of the input), when returned by the model.
|
27 |
+
*/
|
28 |
+
end?: number;
|
29 |
+
/**
|
30 |
+
* A float that represents how likely it is that the answer is correct
|
31 |
+
*/
|
32 |
+
score?: number;
|
33 |
+
/**
|
34 |
+
* ?
|
35 |
+
*/
|
36 |
+
start?: number;
|
37 |
+
}
|
38 |
+
|
39 |
+
/**
|
40 |
+
* Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
|
41 |
+
*/
|
42 |
+
export async function documentQuestionAnswering(
|
43 |
+
args: DocumentQuestionAnsweringArgs,
|
44 |
+
options?: Options
|
45 |
+
): Promise<DocumentQuestionAnsweringOutput> {
|
46 |
+
const reqArgs: RequestArgs = {
|
47 |
+
...args,
|
48 |
+
inputs: {
|
49 |
+
question: args.inputs.question,
|
50 |
+
// convert Blob or ArrayBuffer to base64
|
51 |
+
image: base64FromBytes(
|
52 |
+
new Uint8Array(
|
53 |
+
args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
|
54 |
+
)
|
55 |
+
),
|
56 |
+
},
|
57 |
+
} as RequestArgs;
|
58 |
+
const res = toArray(
|
59 |
+
await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
|
60 |
+
...options,
|
61 |
+
taskHint: "document-question-answering",
|
62 |
+
})
|
63 |
+
)?.[0];
|
64 |
+
const isValidOutput =
|
65 |
+
typeof res?.answer === "string" &&
|
66 |
+
(typeof res.end === "number" || typeof res.end === "undefined") &&
|
67 |
+
(typeof res.score === "number" || typeof res.score === "undefined") &&
|
68 |
+
(typeof res.start === "number" || typeof res.start === "undefined");
|
69 |
+
if (!isValidOutput) {
|
70 |
+
throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
|
71 |
+
}
|
72 |
+
return res;
|
73 |
+
}
|
packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options, RequestArgs } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
import { base64FromBytes } from "../../../../shared";
|
5 |
+
|
6 |
+
export type VisualQuestionAnsweringArgs = BaseArgs & {
|
7 |
+
inputs: {
|
8 |
+
/**
|
9 |
+
* Raw image
|
10 |
+
*
|
11 |
+
* You can use native `File` in browsers, or `new Blob([buffer])` in node, or for a base64 image `new Blob([Uint8Array.from(atob(base64String), (c) => c.charCodeAt(0))])`, or even `await (await fetch('...')).blob()`
|
12 |
+
**/
|
13 |
+
image: Blob | ArrayBuffer;
|
14 |
+
question: string;
|
15 |
+
};
|
16 |
+
};
|
17 |
+
|
18 |
+
export interface VisualQuestionAnsweringOutput {
|
19 |
+
/**
|
20 |
+
* A string that’s the answer to a visual question.
|
21 |
+
*/
|
22 |
+
answer: string;
|
23 |
+
/**
|
24 |
+
* Answer correctness score.
|
25 |
+
*/
|
26 |
+
score: number;
|
27 |
+
}
|
28 |
+
|
29 |
+
/**
|
30 |
+
* Answers a question on an image. Recommended model: dandelin/vilt-b32-finetuned-vqa.
|
31 |
+
*/
|
32 |
+
export async function visualQuestionAnswering(
|
33 |
+
args: VisualQuestionAnsweringArgs,
|
34 |
+
options?: Options
|
35 |
+
): Promise<VisualQuestionAnsweringOutput> {
|
36 |
+
const reqArgs: RequestArgs = {
|
37 |
+
...args,
|
38 |
+
inputs: {
|
39 |
+
question: args.inputs.question,
|
40 |
+
// convert Blob or ArrayBuffer to base64
|
41 |
+
image: base64FromBytes(
|
42 |
+
new Uint8Array(
|
43 |
+
args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.image.arrayBuffer()
|
44 |
+
)
|
45 |
+
),
|
46 |
+
},
|
47 |
+
} as RequestArgs;
|
48 |
+
const res = (
|
49 |
+
await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
|
50 |
+
...options,
|
51 |
+
taskHint: "visual-question-answering",
|
52 |
+
})
|
53 |
+
)?.[0];
|
54 |
+
const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
|
55 |
+
if (!isValidOutput) {
|
56 |
+
throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
|
57 |
+
}
|
58 |
+
return res;
|
59 |
+
}
|
packages/inference/src/tasks/nlp/featureExtraction.ts
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import { getDefaultTask } from "../../lib/getDefaultTask";
|
3 |
+
import type { BaseArgs, Options } from "../../types";
|
4 |
+
import { request } from "../custom/request";
|
5 |
+
|
6 |
+
export type FeatureExtractionArgs = BaseArgs & {
|
7 |
+
/**
|
8 |
+
* The input is a string or a list of strings to get the features from.
|
9 |
+
*
|
10 |
+
* inputs: "That is a happy person",
|
11 |
+
*
|
12 |
+
*/
|
13 |
+
inputs: string | string[];
|
14 |
+
};
|
15 |
+
|
16 |
+
/**
|
17 |
+
* Returned values are a multidimensional array of floats. The dimensions depend on whether you sent a string or a list of strings, and on whether an automatic reduction (usually mean pooling) was applied for you or not; this should be explained in the model's README.
|
18 |
+
*/
|
19 |
+
export type FeatureExtractionOutput = (number | number[] | number[][])[];
|
20 |
+
|
21 |
+
/**
|
22 |
+
* This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
|
23 |
+
*/
|
24 |
+
export async function featureExtraction(
|
25 |
+
args: FeatureExtractionArgs,
|
26 |
+
options?: Options
|
27 |
+
): Promise<FeatureExtractionOutput> {
|
28 |
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
|
29 |
+
|
30 |
+
const res = await request<FeatureExtractionOutput>(args, {
|
31 |
+
...options,
|
32 |
+
taskHint: "feature-extraction",
|
33 |
+
...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
|
34 |
+
});
|
35 |
+
let isValidOutput = true;
|
36 |
+
|
37 |
+
const isNumArrayRec = (arr: unknown[], maxDepth: number, curDepth = 0): boolean => {
|
38 |
+
if (curDepth > maxDepth) return false;
|
39 |
+
if (arr.every((x) => Array.isArray(x))) {
|
40 |
+
return arr.every((x) => isNumArrayRec(x as unknown[], maxDepth, curDepth + 1));
|
41 |
+
} else {
|
42 |
+
return arr.every((x) => typeof x === "number");
|
43 |
+
}
|
44 |
+
};
|
45 |
+
|
46 |
+
isValidOutput = Array.isArray(res) && isNumArrayRec(res, 3, 0);
|
47 |
+
|
48 |
+
if (!isValidOutput) {
|
49 |
+
throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
|
50 |
+
}
|
51 |
+
return res;
|
52 |
+
}
|
packages/inference/src/tasks/nlp/fillMask.ts
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type FillMaskArgs = BaseArgs & {
|
6 |
+
inputs: string;
|
7 |
+
};
|
8 |
+
|
9 |
+
export type FillMaskOutput = {
|
10 |
+
/**
|
11 |
+
* The probability for this token.
|
12 |
+
*/
|
13 |
+
score: number;
|
14 |
+
/**
|
15 |
+
* The actual sequence of tokens that ran against the model (may contain special tokens)
|
16 |
+
*/
|
17 |
+
sequence: string;
|
18 |
+
/**
|
19 |
+
* The id of the token
|
20 |
+
*/
|
21 |
+
token: number;
|
22 |
+
/**
|
23 |
+
* The string representation of the token
|
24 |
+
*/
|
25 |
+
token_str: string;
|
26 |
+
}[];
|
27 |
+
|
28 |
+
/**
|
29 |
+
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
|
30 |
+
*/
|
31 |
+
export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
|
32 |
+
const res = await request<FillMaskOutput>(args, {
|
33 |
+
...options,
|
34 |
+
taskHint: "fill-mask",
|
35 |
+
});
|
36 |
+
const isValidOutput =
|
37 |
+
Array.isArray(res) &&
|
38 |
+
res.every(
|
39 |
+
(x) =>
|
40 |
+
typeof x.score === "number" &&
|
41 |
+
typeof x.sequence === "string" &&
|
42 |
+
typeof x.token === "number" &&
|
43 |
+
typeof x.token_str === "string"
|
44 |
+
);
|
45 |
+
if (!isValidOutput) {
|
46 |
+
throw new InferenceOutputError(
|
47 |
+
"Expected Array<{score: number, sequence: string, token: number, token_str: string}>"
|
48 |
+
);
|
49 |
+
}
|
50 |
+
return res;
|
51 |
+
}
|
packages/inference/src/tasks/nlp/questionAnswering.ts
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type QuestionAnsweringArgs = BaseArgs & {
|
6 |
+
inputs: {
|
7 |
+
context: string;
|
8 |
+
question: string;
|
9 |
+
};
|
10 |
+
};
|
11 |
+
|
12 |
+
export interface QuestionAnsweringOutput {
|
13 |
+
/**
|
14 |
+
* A string that’s the answer within the text.
|
15 |
+
*/
|
16 |
+
answer: string;
|
17 |
+
/**
|
18 |
+
* The index (string wise) of the stop of the answer within context.
|
19 |
+
*/
|
20 |
+
end: number;
|
21 |
+
/**
|
22 |
+
* A float that represents how likely it is that the answer is correct
|
23 |
+
*/
|
24 |
+
score: number;
|
25 |
+
/**
|
26 |
+
* The index (string wise) of the start of the answer within context.
|
27 |
+
*/
|
28 |
+
start: number;
|
29 |
+
}
|
30 |
+
|
31 |
+
/**
|
32 |
+
* Want to have a nice know-it-all bot that can answer any question? Recommended model: deepset/roberta-base-squad2
|
33 |
+
*/
|
34 |
+
export async function questionAnswering(
|
35 |
+
args: QuestionAnsweringArgs,
|
36 |
+
options?: Options
|
37 |
+
): Promise<QuestionAnsweringOutput> {
|
38 |
+
const res = await request<QuestionAnsweringOutput>(args, {
|
39 |
+
...options,
|
40 |
+
taskHint: "question-answering",
|
41 |
+
});
|
42 |
+
const isValidOutput =
|
43 |
+
typeof res === "object" &&
|
44 |
+
!!res &&
|
45 |
+
typeof res.answer === "string" &&
|
46 |
+
typeof res.end === "number" &&
|
47 |
+
typeof res.score === "number" &&
|
48 |
+
typeof res.start === "number";
|
49 |
+
if (!isValidOutput) {
|
50 |
+
throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
|
51 |
+
}
|
52 |
+
return res;
|
53 |
+
}
|
packages/inference/src/tasks/nlp/sentenceSimilarity.ts
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import { getDefaultTask } from "../../lib/getDefaultTask";
|
3 |
+
import type { BaseArgs, Options } from "../../types";
|
4 |
+
import { request } from "../custom/request";
|
5 |
+
|
6 |
+
export type SentenceSimilarityArgs = BaseArgs & {
|
7 |
+
/**
|
8 |
+
* The inputs vary based on the model.
|
9 |
+
*
|
10 |
+
* For example when using sentence-transformers/paraphrase-xlm-r-multilingual-v1 the inputs will have a `source_sentence` string and
|
11 |
+
* a `sentences` array of strings
|
12 |
+
*/
|
13 |
+
inputs: Record<string, unknown> | Record<string, unknown>[];
|
14 |
+
};
|
15 |
+
|
16 |
+
/**
|
17 |
+
* Returned values are a list of floats
|
18 |
+
*/
|
19 |
+
export type SentenceSimilarityOutput = number[];
|
20 |
+
|
21 |
+
/**
|
22 |
+
* Calculate the semantic similarity between one text and a list of other sentences by comparing their embeddings.
|
23 |
+
*/
|
24 |
+
export async function sentenceSimilarity(
|
25 |
+
args: SentenceSimilarityArgs,
|
26 |
+
options?: Options
|
27 |
+
): Promise<SentenceSimilarityOutput> {
|
28 |
+
const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
|
29 |
+
const res = await request<SentenceSimilarityOutput>(args, {
|
30 |
+
...options,
|
31 |
+
taskHint: "sentence-similarity",
|
32 |
+
...(defaultTask === "feature-extraction" && { forceTask: "sentence-similarity" }),
|
33 |
+
});
|
34 |
+
|
35 |
+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
|
36 |
+
if (!isValidOutput) {
|
37 |
+
throw new InferenceOutputError("Expected number[]");
|
38 |
+
}
|
39 |
+
return res;
|
40 |
+
}
|
packages/inference/src/tasks/nlp/summarization.ts
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type SummarizationArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* A string to be summarized
|
8 |
+
*/
|
9 |
+
inputs: string;
|
10 |
+
parameters?: {
|
11 |
+
/**
|
12 |
+
* (Default: None). Integer to define the maximum length in tokens of the output summary.
|
13 |
+
*/
|
14 |
+
max_length?: number;
|
15 |
+
/**
|
16 |
+
* (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take. Network can cause some overhead so it will be a soft limit.
|
17 |
+
*/
|
18 |
+
max_time?: number;
|
19 |
+
/**
|
20 |
+
* (Default: None). Integer to define the minimum length in tokens of the output summary.
|
21 |
+
*/
|
22 |
+
min_length?: number;
|
23 |
+
/**
|
24 |
+
* (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
|
25 |
+
*/
|
26 |
+
repetition_penalty?: number;
|
27 |
+
/**
|
28 |
+
* (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
|
29 |
+
*/
|
30 |
+
temperature?: number;
|
31 |
+
/**
|
32 |
+
* (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
|
33 |
+
*/
|
34 |
+
top_k?: number;
|
35 |
+
/**
|
36 |
+
* (Default: None). Float to define the tokens that are within the sample operation of text generation. Tokens are added to the sample from most probable to least probable until the sum of the probabilities is greater than top_p.
|
37 |
+
*/
|
38 |
+
top_p?: number;
|
39 |
+
};
|
40 |
+
};
|
41 |
+
|
42 |
+
export interface SummarizationOutput {
|
43 |
+
/**
|
44 |
+
* The string after summarization
|
45 |
+
*/
|
46 |
+
summary_text: string;
|
47 |
+
}
|
48 |
+
|
49 |
+
/**
|
50 |
+
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
|
51 |
+
*/
|
52 |
+
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
|
53 |
+
const res = await request<SummarizationOutput[]>(args, {
|
54 |
+
...options,
|
55 |
+
taskHint: "summarization",
|
56 |
+
});
|
57 |
+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.summary_text === "string");
|
58 |
+
if (!isValidOutput) {
|
59 |
+
throw new InferenceOutputError("Expected Array<{summary_text: string}>");
|
60 |
+
}
|
61 |
+
return res?.[0];
|
62 |
+
}
|
packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type TableQuestionAnsweringArgs = BaseArgs & {
|
6 |
+
inputs: {
|
7 |
+
/**
|
8 |
+
* The query in plain text that you want to ask the table
|
9 |
+
*/
|
10 |
+
query: string;
|
11 |
+
/**
|
12 |
+
* A table of data represented as a dict of list where entries are headers and the lists are all the values, all lists must have the same size.
|
13 |
+
*/
|
14 |
+
table: Record<string, string[]>;
|
15 |
+
};
|
16 |
+
};
|
17 |
+
|
18 |
+
export interface TableQuestionAnsweringOutput {
|
19 |
+
/**
|
20 |
+
* The aggregator used to get the answer
|
21 |
+
*/
|
22 |
+
aggregator: string;
|
23 |
+
/**
|
24 |
+
* The plaintext answer
|
25 |
+
*/
|
26 |
+
answer: string;
|
27 |
+
/**
|
28 |
+
* A list of the contents of the cells referenced in the answer
|
29 |
+
*/
|
30 |
+
cells: string[];
|
31 |
+
/**
|
32 |
+
* A list of coordinates of the cells referenced in the answer
|
33 |
+
*/
|
34 |
+
coordinates: number[][];
|
35 |
+
}
|
36 |
+
|
37 |
+
/**
|
38 |
+
* Don’t know SQL? Don’t want to dive into a large spreadsheet? Ask questions in plain English! Recommended model: google/tapas-base-finetuned-wtq.
|
39 |
+
*/
|
40 |
+
export async function tableQuestionAnswering(
|
41 |
+
args: TableQuestionAnsweringArgs,
|
42 |
+
options?: Options
|
43 |
+
): Promise<TableQuestionAnsweringOutput> {
|
44 |
+
const res = await request<TableQuestionAnsweringOutput>(args, {
|
45 |
+
...options,
|
46 |
+
taskHint: "table-question-answering",
|
47 |
+
});
|
48 |
+
const isValidOutput =
|
49 |
+
typeof res?.aggregator === "string" &&
|
50 |
+
typeof res.answer === "string" &&
|
51 |
+
Array.isArray(res.cells) &&
|
52 |
+
res.cells.every((x) => typeof x === "string") &&
|
53 |
+
Array.isArray(res.coordinates) &&
|
54 |
+
res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
|
55 |
+
if (!isValidOutput) {
|
56 |
+
throw new InferenceOutputError(
|
57 |
+
"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
|
58 |
+
);
|
59 |
+
}
|
60 |
+
return res;
|
61 |
+
}
|
packages/inference/src/tasks/nlp/textClassification.ts
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
export type TextClassificationArgs = BaseArgs & {
|
6 |
+
/**
|
7 |
+
* A string to be classified
|
8 |
+
*/
|
9 |
+
inputs: string;
|
10 |
+
};
|
11 |
+
|
12 |
+
export type TextClassificationOutput = {
|
13 |
+
/**
|
14 |
+
* The label for the class (model specific)
|
15 |
+
*/
|
16 |
+
label: string;
|
17 |
+
/**
|
18 |
+
* A float that represents how likely it is that the text belongs to this class.
|
19 |
+
*/
|
20 |
+
score: number;
|
21 |
+
}[];
|
22 |
+
|
23 |
+
/**
 * Outputs the likelihood of each class for an input text; usually used for sentiment analysis.
 * Recommended model: distilbert-base-uncased-finetuned-sst-2-english
 *
 * @param args - Base arguments plus the `inputs` string to classify.
 * @param options - Request options; `taskHint` is always set to "text-classification".
 * @returns The first element of the response, an array of `{ label, score }` entries.
 * @throws InferenceOutputError when that element is not an array of `{ label: string, score: number }`.
 */
export async function textClassification(
	args: TextClassificationArgs,
	options?: Options
): Promise<TextClassificationOutput> {
	const response = await request<TextClassificationOutput[]>(args, {
		...options,
		taskHint: "text-classification",
	});
	// The response is typed as an array of outputs; only its first element is used.
	const classes = response?.[0];
	if (
		!Array.isArray(classes) ||
		!classes.every((entry) => typeof entry?.label === "string" && typeof entry.score === "number")
	) {
		throw new InferenceOutputError("Expected Array<{label: string, score: number}>");
	}
	return classes;
}
|
packages/inference/src/tasks/nlp/textGeneration.ts
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
|
2 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
3 |
+
import type { BaseArgs, Options } from "../../types";
|
4 |
+
import { request } from "../custom/request";
|
5 |
+
|
6 |
+
/**
 * Use to continue text from a prompt. This is a very generic task.
 * Recommended model: gpt2 (it’s a simple model, but fun to play with).
 *
 * @param args - Base arguments combined with the text-generation input.
 * @param options - Request options; `taskHint` is always set to "text-generation".
 * @returns The first generated output of the response.
 * @throws InferenceOutputError when the response is not a non-empty array of `{ generated_text: string }`.
 */
export async function textGeneration(
	args: BaseArgs & TextGenerationInput,
	options?: Options
): Promise<TextGenerationOutput> {
	const res = await request<TextGenerationOutput[]>(args, {
		...options,
		taskHint: "text-generation",
	});
	// `every` is vacuously true for an empty array, which would make `res[0]` undefined
	// and silently break the declared return type; treat an empty response as invalid.
	const isValidOutput =
		Array.isArray(res) && res.length > 0 && res.every((x) => typeof x?.generated_text === "string");
	if (!isValidOutput) {
		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
	return res[0];
}
|
packages/inference/src/tasks/nlp/textGenerationStream.ts
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { BaseArgs, Options } from "../../types";
|
2 |
+
import { streamingRequest } from "../custom/streamingRequest";
|
3 |
+
|
4 |
+
import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";
|
5 |
+
|
6 |
+
/**
 * A single generated token emitted by the text-generation stream.
 */
export interface TextGenerationStreamToken {
	/** Token ID from the model tokenizer */
	id: number;
	/** Token text */
	text: string;
	/** Logprob */
	logprob: number;
	/**
	 * Whether the token is a special token.
	 * Can be used to ignore tokens when concatenating.
	 */
	special: boolean;
}
|
19 |
+
|
20 |
+
/**
 * A prompt ("prefill") token reported in the generation details.
 */
export interface TextGenerationStreamPrefillToken {
	/** Token ID from the model tokenizer */
	id: number;
	/** Token text */
	text: string;
	/**
	 * Logprob.
	 * Optional since the logprob of the first token cannot be computed.
	 */
	logprob?: number;
}
|
31 |
+
|
32 |
+
/**
 * One additional sequence reported in `best_of_sequences`.
 */
export interface TextGenerationStreamBestOfSequence {
	/** Generated text */
	generated_text: string;
	/** Generation finish reason */
	finish_reason: TextGenerationStreamFinishReason;
	/** Number of generated tokens */
	generated_tokens: number;
	/** Sampling seed if sampling was activated */
	seed?: number;
	/** Prompt tokens */
	prefill: Array<TextGenerationStreamPrefillToken>;
	/** Generated tokens */
	tokens: Array<TextGenerationStreamToken>;
}
|
46 |
+
|
47 |
+
/**
 * Reason why a generation stopped.
 */
export type TextGenerationStreamFinishReason =
	/** The number of generated tokens reached `max_new_tokens`. */
	| "length"
	/** The model generated its end-of-sequence token. */
	| "eos_token"
	/** The model generated a text included in `stop_sequences`. */
	| "stop_sequence";
|
54 |
+
|
55 |
+
/**
 * Details about a finished generation.
 */
export interface TextGenerationStreamDetails {
	/** Generation finish reason */
	finish_reason: TextGenerationStreamFinishReason;
	/** Number of generated tokens */
	generated_tokens: number;
	/** Sampling seed if sampling was activated */
	seed?: number;
	/** Prompt tokens */
	prefill: Array<TextGenerationStreamPrefillToken>;
	/** Generated tokens */
	tokens: Array<TextGenerationStreamToken>;
	/** Additional sequences when using the `best_of` parameter */
	best_of_sequences?: Array<TextGenerationStreamBestOfSequence>;
}
|
69 |
+
|
70 |
+
/**
 * One item yielded by `textGenerationStream`.
 */
export interface TextGenerationStreamOutput {
	/** Generated token, one at a time */
	token: TextGenerationStreamToken;
	/**
	 * Complete generated text.
	 * Only available when the generation is finished.
	 */
	generated_text: string | null;
	/**
	 * Generation details.
	 * Only available when the generation is finished.
	 */
	details: TextGenerationStreamDetails | null;
}
|
84 |
+
|
85 |
+
/**
 * Use to continue text from a prompt. Same as `textGeneration` but returns a generator
 * that can be read one token at a time.
 *
 * @param args - Base arguments combined with the text-generation input.
 * @param options - Request options; `taskHint` is always set to "text-generation".
 * @returns An async generator that re-yields every item produced by `streamingRequest`.
 */
export async function* textGenerationStream(
	args: BaseArgs & TextGenerationInput,
	options?: Options
): AsyncGenerator<TextGenerationStreamOutput> {
	// Caller options are spread first so the task hint cannot be overridden.
	const streamOptions = { ...options, taskHint: "text-generation" as const };
	yield* streamingRequest<TextGenerationStreamOutput>(args, streamOptions);
}
|
packages/inference/src/tasks/nlp/tokenClassification.ts
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { toArray } from "../../utils/toArray";
|
4 |
+
import { request } from "../custom/request";
|
5 |
+
|
6 |
+
/**
 * Arguments accepted by `tokenClassification`.
 */
export type TokenClassificationArgs = BaseArgs & {
	/** The text to classify. */
	inputs: string;
	parameters?: {
		/**
		 * Aggregation strategy (Default: simple):
		 *
		 * - none: Every token gets classified without further aggregation.
		 * - simple: Entities are grouped according to the default schema (B-, I- tags get merged when the tag is similar).
		 * - first: Same as the simple strategy except words cannot end up with different tags. Words will use the tag of the first token when there is ambiguity.
		 * - average: Same as the simple strategy except words cannot end up with different tags. Scores are averaged across tokens and then the maximum label is applied.
		 * - max: Same as the simple strategy except words cannot end up with different tags. Word entity will be the token with the maximum score.
		 */
		aggregation_strategy?: "none" | "simple" | "first" | "average" | "max";
	};
};
|
28 |
+
|
29 |
+
/**
 * One recognized entity in a token classification result.
 */
export interface TokenClassificationOutputValue {
	/**
	 * Offset within the input string marking the end of the captured span.
	 * Useful to disambiguate if the word occurs multiple times.
	 */
	end: number;
	/** The type for the entity being recognized (model specific). */
	entity_group: string;
	/** How likely the entity was recognized. */
	score: number;
	/**
	 * Offset within the input string marking the start of the captured span.
	 * Useful to disambiguate if the word occurs multiple times.
	 */
	start: number;
	/** The string that was captured. */
	word: string;
}

/** Full token classification result: a list of recognized entities. */
export type TokenClassificationOutput = Array<TokenClassificationOutputValue>;
|
53 |
+
|
54 |
+
/**
 * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER)
 * to understand keywords contained within text.
 * Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
 *
 * @param args - Base arguments plus the `inputs` string and optional `parameters`.
 * @param options - Request options; `taskHint` is always set to "token-classification".
 * @returns The recognized entities; a single-object response is wrapped into an array.
 * @throws InferenceOutputError when any element does not match the expected entity shape.
 */
export async function tokenClassification(
	args: TokenClassificationArgs,
	options?: Options
): Promise<TokenClassificationOutput> {
	const res = toArray(
		await request<TokenClassificationOutput[number] | TokenClassificationOutput>(args, {
			...options,
			taskHint: "token-classification",
		})
	);
	const isValidOutput =
		Array.isArray(res) &&
		res.every(
			(x) =>
				// `toArray` wraps a null/undefined response into `[null]`; the optional chain makes
				// that case fail validation instead of throwing a raw TypeError.
				typeof x?.end === "number" &&
				typeof x.entity_group === "string" &&
				typeof x.score === "number" &&
				typeof x.start === "number" &&
				typeof x.word === "string"
		);
	if (!isValidOutput) {
		throw new InferenceOutputError(
			"Expected Array<{end: number, entity_group: string, score: number, start: number, word: string}>"
		);
	}
	return res;
}
|
packages/inference/src/tasks/nlp/translation.ts
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
/**
 * Arguments accepted by `translation`.
 */
export type TranslationArgs = BaseArgs & {
	/** The string, or list of strings, to be translated. */
	inputs: string | string[];
};
|
11 |
+
|
12 |
+
/**
 * A single translation result.
 */
export interface TranslationOutputValue {
	/** The string after translation. */
	translation_text: string;
}

/** Either a single translation or a list of translations. */
export type TranslationOutput = TranslationOutputValue | Array<TranslationOutputValue>;
|
20 |
+
|
21 |
+
/**
 * Translates text from one language to another.
 * Recommended model: Helsinki-NLP/opus-mt-ru-en.
 *
 * @param args - Base arguments plus the `inputs` string(s) to translate.
 * @param options - Request options; `taskHint` is always set to "translation".
 * @returns The single translation when the response holds exactly one element, otherwise the whole array.
 * @throws InferenceOutputError when the response is not an array of `{ translation_text: string }`.
 */
export async function translation(args: TranslationArgs, options?: Options): Promise<TranslationOutput> {
	const translations = await request<TranslationOutputValue[]>(args, {
		...options,
		taskHint: "translation",
	});
	if (!Array.isArray(translations) || translations.some((item) => typeof item?.translation_text !== "string")) {
		throw new InferenceOutputError("Expected type Array<{translation_text: string}>");
	}
	// A one-element response is unwrapped; any other length is returned as-is.
	if (translations.length === 1) {
		return translations[0];
	}
	return translations;
}
|
packages/inference/src/tasks/nlp/zeroShotClassification.ts
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { toArray } from "../../utils/toArray";
|
4 |
+
import { request } from "../custom/request";
|
5 |
+
|
6 |
+
/**
 * Arguments accepted by `zeroShotClassification`.
 */
export type ZeroShotClassificationArgs = BaseArgs & {
	/** A string or list of strings. */
	inputs: string | string[];
	parameters: {
		/**
		 * A list of strings that are potential classes for inputs (max 10 candidate_labels; for more, simply run
		 * multiple requests, results are going to be misleading if using too many candidate_labels anyway).
		 * If you want to keep the exact same, you can simply run multi_label=True and do the scaling on your end.
		 */
		candidate_labels: string[];
		/** (Default: false) Boolean that is set to True if classes can overlap. */
		multi_label?: boolean;
	};
};
|
22 |
+
|
23 |
+
/**
 * One zero-shot classification result.
 * NOTE(review): presumably `labels` and `scores` are index-aligned and `sequence` echoes the input — confirm against the API docs.
 */
export interface ZeroShotClassificationOutputValue {
	labels: Array<string>;
	scores: Array<number>;
	sequence: string;
}

/** Full zero-shot classification result: a list of per-sequence results. */
export type ZeroShotClassificationOutput = Array<ZeroShotClassificationOutputValue>;
|
30 |
+
|
31 |
+
/**
 * This task is super useful to try out classification with zero code: you simply pass a sentence/paragraph
 * and the possible labels for that sentence, and you get a result.
 * Recommended model: facebook/bart-large-mnli.
 *
 * @param args - Base arguments plus `inputs` and the `parameters` holding the candidate labels.
 * @param options - Request options; `taskHint` is always set to "zero-shot-classification".
 * @returns The classification results; a single-object response is wrapped into an array.
 * @throws InferenceOutputError when any element does not match `{ labels: string[], scores: number[], sequence: string }`.
 */
export async function zeroShotClassification(
	args: ZeroShotClassificationArgs,
	options?: Options
): Promise<ZeroShotClassificationOutput> {
	const res = toArray(
		await request<ZeroShotClassificationOutput[number] | ZeroShotClassificationOutput>(args, {
			...options,
			taskHint: "zero-shot-classification",
		})
	);
	const isValidOutput =
		Array.isArray(res) &&
		res.every(
			(x) =>
				// `toArray` wraps a null/undefined response into `[null]`; the optional chain makes
				// that case fail validation instead of throwing a raw TypeError.
				Array.isArray(x?.labels) &&
				x.labels.every((_label) => typeof _label === "string") &&
				Array.isArray(x.scores) &&
				x.scores.every((_score) => typeof _score === "number") &&
				typeof x.sequence === "string"
		);
	if (!isValidOutput) {
		throw new InferenceOutputError("Expected Array<{labels: string[], scores: number[], sequence: string}>");
	}
	return res;
}
|
packages/inference/src/tasks/tabular/tabularClassification.ts
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
/**
 * Arguments accepted by `tabularClassification`.
 */
export type TabularClassificationArgs = BaseArgs & {
	inputs: {
		/**
		 * A table of data represented as a dict of lists, where the keys are the headers and each list holds
		 * that column's values. All lists must have the same size.
		 */
		data: Record<string, string[]>;
	};
};
|
13 |
+
|
14 |
+
/**
 * The predicted labels, one per row.
 */
export type TabularClassificationOutput = Array<number>;
|
18 |
+
|
19 |
+
/**
 * Predicts the target label for a given set of features in tabular form.
 * Typically, you will want to train a classification model on your training data and use it with
 * your new data of the same format.
 * Example model: vvmnnnkv/wine-quality
 *
 * @param args - Base arguments plus the tabular `inputs.data`.
 * @param options - Request options; `taskHint` is always set to "tabular-classification".
 * @returns The response array of numbers.
 * @throws InferenceOutputError when the response is not an array of numbers.
 */
export async function tabularClassification(
	args: TabularClassificationArgs,
	options?: Options
): Promise<TabularClassificationOutput> {
	const predictions = await request<TabularClassificationOutput>(args, {
		...options,
		taskHint: "tabular-classification",
	});
	if (!Array.isArray(predictions) || predictions.some((value) => typeof value !== "number")) {
		throw new InferenceOutputError("Expected number[]");
	}
	return predictions;
}
|
packages/inference/src/tasks/tabular/tabularRegression.ts
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { InferenceOutputError } from "../../lib/InferenceOutputError";
|
2 |
+
import type { BaseArgs, Options } from "../../types";
|
3 |
+
import { request } from "../custom/request";
|
4 |
+
|
5 |
+
/**
 * Arguments accepted by `tabularRegression`.
 */
export type TabularRegressionArgs = BaseArgs & {
	inputs: {
		/**
		 * A table of data represented as a dict of lists, where the keys are the headers and each list holds
		 * that column's values. All lists must have the same size.
		 */
		data: Record<string, string[]>;
	};
};
|
13 |
+
|
14 |
+
/**
 * The predicted values, one per row.
 */
export type TabularRegressionOutput = Array<number>;
|
18 |
+
|
19 |
+
/**
 * Predicts the target value for a given set of features in tabular form.
 * Typically, you will want to train a regression model on your training data and use it with
 * your new data of the same format.
 * Example model: scikit-learn/Fish-Weight
 *
 * @param args - Base arguments plus the tabular `inputs.data`.
 * @param options - Request options; `taskHint` is always set to "tabular-regression".
 * @returns The response array of numbers.
 * @throws InferenceOutputError when the response is not an array of numbers.
 */
export async function tabularRegression(
	args: TabularRegressionArgs,
	options?: Options
): Promise<TabularRegressionOutput> {
	const predictions = await request<TabularRegressionOutput>(args, {
		...options,
		taskHint: "tabular-regression",
	});
	if (!Array.isArray(predictions) || predictions.some((value) => typeof value !== "number")) {
		throw new InferenceOutputError("Expected number[]");
	}
	return predictions;
}
|
packages/inference/src/types.ts
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import type { PipelineType } from "@huggingface/tasks";
|
2 |
+
|
3 |
+
/**
 * Per-request options shared by every inference task.
 */
export interface Options {
	/**
	 * (Default: true) Boolean. If a request 503s and wait_for_model is set to false, the request will be
	 * retried with the same parameters but with wait_for_model set to true.
	 */
	retry_on_error?: boolean;
	/**
	 * (Default: true) Boolean. There is a cache layer on Inference API (serverless) to speed up requests we
	 * have already seen. Most models can use those results as is, since models are deterministic (meaning the
	 * results will be the same anyway). However, if you use a non-deterministic model, you can set this
	 * parameter to prevent the caching mechanism from being used, resulting in a real new query.
	 */
	use_cache?: boolean;
	/**
	 * (Default: false) Boolean. Do not load the model if it's not already available.
	 */
	dont_load_model?: boolean;
	/**
	 * (Default: false) Boolean to use GPU instead of CPU for inference (requires Startup plan at least).
	 */
	use_gpu?: boolean;

	/**
	 * (Default: false) Boolean. If the model is not ready, wait for it instead of receiving 503. It limits
	 * the number of requests required to get your inference done. It is advised to only set this flag to true
	 * after receiving a 503 error, as it will limit hanging in your application to known places.
	 */
	wait_for_model?: boolean;
	/**
	 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
	 */
	fetch?: typeof fetch;
	/**
	 * Abort Controller signal to use for request interruption.
	 */
	signal?: AbortSignal;

	/**
	 * (Default: "same-origin") String | Boolean. Credentials to use for the request. If this is a string, it
	 * will be passed straight on. If it's a boolean, true will be "include" and false will not send
	 * credentials at all.
	 */
	includeCredentials?: string | boolean;
}
|
39 |
+
|
40 |
+
/** Every pipeline type except "other". */
export type InferenceTask = Exclude<PipelineType, "other">;

/**
 * Arguments common to every inference call.
 */
export interface BaseArgs {
	/**
	 * The access token to use. Without it, you'll get rate-limited quickly.
	 *
	 * Can be created for free in hf.co/settings/token
	 */
	accessToken?: string;
	/**
	 * The model to use. Can be a full URL for a dedicated inference endpoint.
	 *
	 * If not specified, will call huggingface.co/api/tasks to get the default model for the task.
	 */
	model?: string;
}
|
56 |
+
|
57 |
+
/**
 * Arguments of a raw request: the base arguments, a payload that is either binary `data` or
 * arbitrary `inputs`, and optional free-form `parameters`.
 */
export type RequestArgs = BaseArgs &
	({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
		parameters?: Record<string, unknown>;
		// Also declared on BaseArgs; repeated here with the same type.
		accessToken?: string;
	};
|
packages/inference/src/utils/distributive-omit.d.ts
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// https://dev.to/safareli/pick-omit-and-union-types-in-typescript-4nd9
// https://github.com/microsoft/TypeScript/issues/28339#issuecomment-467393437

/** Keys of a single type. */
type Keys<T> = keyof T;
/** Keys of each member of the union `T`, collected by distributing over `T`. */
type DistributiveKeys<T> = T extends unknown ? Keys<T> : never;
/** `Omit` applied only to those members of `K` that are actually keys of `T`. */
type Omit_<T, K> = Omit<T, Extract<keyof T, K>>;

/**
 * This allows omitting keys from objects inside unions, without merging the individual components of the union.
 *
 * A union member that has no keys left after the omission resolves to `never`.
 */
export type DistributiveOmit<T, K> = T extends unknown
	? keyof Omit_<T, K> extends never
		? never
		: { [P in keyof Omit_<T, K>]: Omit_<T, K>[P] }
	: never;
|
packages/inference/src/utils/omit.ts
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { pick } from "./pick";
|
2 |
+
import { typedInclude } from "./typedInclude";
|
3 |
+
|
4 |
+
/**
 * Returns a copy of the object without the blocklisted properties.
 *
 * @param o - Source object.
 * @param props - A single property name or an array of property names to leave out.
 * @returns The result of `pick` over the remaining own enumerable keys of `o`.
 */
export function omit<T extends object, K extends keyof T>(o: T, props: K[] | K): Pick<T, Exclude<keyof T, K>> {
	const blocked = Array.isArray(props) ? props : [props];
	const kept: (keyof T)[] = [];
	for (const key of Object.keys(o) as (keyof T)[]) {
		if (!typedInclude(blocked, key)) {
			kept.push(key);
		}
	}
	return pick(o, kept);
}
|
packages/inference/src/utils/pick.ts
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Returns a copy of the object that keeps only the allowlisted properties.
 *
 * @param o - Source object.
 * @param props - Property names to keep.
 * @returns A new object holding each listed property whose value on `o` is not `undefined`.
 */
export function pick<T, K extends keyof T>(o: T, props: K[] | ReadonlyArray<K>): Pick<T, K> {
	const picked = {} as Pick<T, K>;
	props.forEach((key) => {
		// Properties whose value is undefined are left out of the copy entirely.
		if (o[key] !== undefined) {
			picked[key] = o[key];
		}
	});
	return picked;
}
|
packages/inference/src/utils/toArray.ts
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Wraps a value into an array unless it already is one.
 *
 * @param obj - Any value.
 * @returns `obj` itself when it is an array, otherwise a new one-element array containing `obj`.
 */
export function toArray<T>(obj: T): T extends unknown[] ? T : T[] {
	type Wrapped = T extends unknown[] ? T : T[];
	if (!Array.isArray(obj)) {
		return [obj] as Wrapped;
	}
	return obj as Wrapped;
}
|
packages/inference/src/utils/typedInclude.ts
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
/**
 * Type-guard wrapper around `Array.prototype.includes`.
 *
 * @param arr - Array of candidate values.
 * @param v - Value to look up.
 * @returns Whether `arr` includes `v`; when true, `v` is narrowed to the array's element type.
 */
export function typedInclude<V, T extends V>(arr: readonly T[], v: V): v is T {
	const candidate = v as T;
	return arr.includes(candidate);
}
|
packages/inference/src/vendor/fetch-event-source/parse.spec.ts
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { expect, it, describe } from "vitest";
|
2 |
+
/** Test helper: throws an Error carrying the given message. */
const fail = (msg: string): never => {
	throw new Error(msg);
};
|
3 |
+
|
4 |
+
/**
|
5 |
+
This file is a part of fetch-event-source package (as of v2.0.1)
|
6 |
+
https://github.com/Azure/fetch-event-source/blob/v2.0.1/src/parse.spec.ts
|
7 |
+
|
8 |
+
Full package can be used after it is made compatible with nodejs:
|
9 |
+
https://github.com/Azure/fetch-event-source/issues/20
|
10 |
+
|
11 |
+
Below is the fetch-event-source package license:
|
12 |
+
|
13 |
+
MIT License
|
14 |
+
|
15 |
+
Copyright (c) Microsoft Corporation.
|
16 |
+
|
17 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
18 |
+
of this software and associated documentation files (the "Software"), to deal
|
19 |
+
in the Software without restriction, including without limitation the rights
|
20 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
21 |
+
copies of the Software, and to permit persons to whom the Software is
|
22 |
+
furnished to do so, subject to the following conditions:
|
23 |
+
|
24 |
+
The above copyright notice and this permission notice shall be included in all
|
25 |
+
copies or substantial portions of the Software.
|
26 |
+
|
27 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
28 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
29 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
30 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
31 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
32 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
33 |
+
SOFTWARE
|
34 |
+
|
35 |
+
*/
|
36 |
+
|
37 |
+
import * as parse from './parse';
|
38 |
+
|
39 |
+
describe('parse', () => {
|
40 |
+
const encoder = new TextEncoder();
|
41 |
+
const decoder = new TextDecoder();
|
42 |
+
|
43 |
+
describe('getLines', () => {
|
44 |
+
it('single line', () => {
|
45 |
+
// arrange:
|
46 |
+
let lineNum = 0;
|
47 |
+
const next = parse.getLines((line, fieldLength) => {
|
48 |
+
++lineNum;
|
49 |
+
expect(decoder.decode(line)).toEqual('id: abc');
|
50 |
+
expect(fieldLength).toEqual(2);
|
51 |
+
});
|
52 |
+
|
53 |
+
// act:
|
54 |
+
next(encoder.encode('id: abc\n'));
|
55 |
+
|
56 |
+
// assert:
|
57 |
+
expect(lineNum).toBe(1);
|
58 |
+
});
|
59 |
+
|
60 |
+
it('multiple lines', () => {
|
61 |
+
// arrange:
|
62 |
+
let lineNum = 0;
|
63 |
+
const next = parse.getLines((line, fieldLength) => {
|
64 |
+
++lineNum;
|
65 |
+
expect(decoder.decode(line)).toEqual(lineNum === 1 ? 'id: abc' : 'data: def');
|
66 |
+
expect(fieldLength).toEqual(lineNum === 1 ? 2 : 4);
|
67 |
+
});
|
68 |
+
|
69 |
+
// act:
|
70 |
+
next(encoder.encode('id: abc\n'));
|
71 |
+
next(encoder.encode('data: def\n'));
|
72 |
+
|
73 |
+
// assert:
|
74 |
+
expect(lineNum).toBe(2);
|
75 |
+
});
|
76 |
+
|
77 |
+
it('single line split across multiple arrays', () => {
|
78 |
+
// arrange:
|
79 |
+
let lineNum = 0;
|
80 |
+
const next = parse.getLines((line, fieldLength) => {
|
81 |
+
++lineNum;
|
82 |
+
expect(decoder.decode(line)).toEqual('id: abc');
|
83 |
+
expect(fieldLength).toEqual(2);
|
84 |
+
});
|
85 |
+
|
86 |
+
// act:
|
87 |
+
next(encoder.encode('id: a'));
|
88 |
+
next(encoder.encode('bc\n'));
|
89 |
+
|
90 |
+
// assert:
|
91 |
+
expect(lineNum).toBe(1);
|
92 |
+
});
|
93 |
+
|
94 |
+
it('multiple lines split across multiple arrays', () => {
|
95 |
+
// arrange:
|
96 |
+
let lineNum = 0;
|
97 |
+
const next = parse.getLines((line, fieldLength) => {
|
98 |
+
++lineNum;
|
99 |
+
expect(decoder.decode(line)).toEqual(lineNum === 1 ? 'id: abc' : 'data: def');
|
100 |
+
expect(fieldLength).toEqual(lineNum === 1 ? 2 : 4);
|
101 |
+
});
|
102 |
+
|
103 |
+
// act:
|
104 |
+
next(encoder.encode('id: ab'));
|
105 |
+
next(encoder.encode('c\nda'));
|
106 |
+
next(encoder.encode('ta: def\n'));
|
107 |
+
|
108 |
+
// assert:
|
109 |
+
expect(lineNum).toBe(2);
|
110 |
+
});
|
111 |
+
|
112 |
+
it('new line', () => {
|
113 |
+
// arrange:
|
114 |
+
let lineNum = 0;
|
115 |
+
const next = parse.getLines((line, fieldLength) => {
|
116 |
+
++lineNum;
|
117 |
+
expect(decoder.decode(line)).toEqual('');
|
118 |
+
expect(fieldLength).toEqual(-1);
|
119 |
+
});
|
120 |
+
|
121 |
+
// act:
|
122 |
+
next(encoder.encode('\n'));
|
123 |
+
|
124 |
+
// assert:
|
125 |
+
expect(lineNum).toBe(1);
|
126 |
+
});
|
127 |
+
|
128 |
+
it('comment line', () => {
|
129 |
+
// arrange:
|
130 |
+
let lineNum = 0;
|
131 |
+
const next = parse.getLines((line, fieldLength) => {
|
132 |
+
++lineNum;
|
133 |
+
expect(decoder.decode(line)).toEqual(': this is a comment');
|
134 |
+
expect(fieldLength).toEqual(0);
|
135 |
+
});
|
136 |
+
|
137 |
+
// act:
|
138 |
+
next(encoder.encode(': this is a comment\n'));
|
139 |
+
|
140 |
+
// assert:
|
141 |
+
expect(lineNum).toBe(1);
|
142 |
+
});
|
143 |
+
|
144 |
+
it('line with no field', () => {
|
145 |
+
// arrange:
|
146 |
+
let lineNum = 0;
|
147 |
+
const next = parse.getLines((line, fieldLength) => {
|
148 |
+
++lineNum;
|
149 |
+
expect(decoder.decode(line)).toEqual('this is an invalid line');
|
150 |
+
expect(fieldLength).toEqual(-1);
|
151 |
+
});
|
152 |
+
|
153 |
+
// act:
|
154 |
+
next(encoder.encode('this is an invalid line\n'));
|
155 |
+
|
156 |
+
// assert:
|
157 |
+
expect(lineNum).toBe(1);
|
158 |
+
});
|
159 |
+
|
160 |
+
it('line with multiple colons', () => {
|
161 |
+
// arrange:
|
162 |
+
let lineNum = 0;
|
163 |
+
const next = parse.getLines((line, fieldLength) => {
|
164 |
+
++lineNum;
|
165 |
+
expect(decoder.decode(line)).toEqual('id: abc: def');
|
166 |
+
expect(fieldLength).toEqual(2);
|
167 |
+
});
|
168 |
+
|
169 |
+
// act:
|
170 |
+
next(encoder.encode('id: abc: def\n'));
|
171 |
+
|
172 |
+
// assert:
|
173 |
+
expect(lineNum).toBe(1);
|
174 |
+
});
|
175 |
+
|
176 |
+
it('single byte array with multiple lines separated by \\n', () => {
|
177 |
+
// arrange:
|
178 |
+
let lineNum = 0;
|
179 |
+
const next = parse.getLines((line, fieldLength) => {
|
180 |
+
++lineNum;
|
181 |
+
expect(decoder.decode(line)).toEqual(lineNum === 1 ? 'id: abc' : 'data: def');
|
182 |
+
expect(fieldLength).toEqual(lineNum === 1 ? 2 : 4);
|
183 |
+
});
|
184 |
+
|
185 |
+
// act:
|
186 |
+
next(encoder.encode('id: abc\ndata: def\n'));
|
187 |
+
|
188 |
+
// assert:
|
189 |
+
expect(lineNum).toBe(2);
|
190 |
+
});
|
191 |
+
|
192 |
+
it('single byte array with multiple lines separated by \\r', () => {
|
193 |
+
// arrange:
|
194 |
+
let lineNum = 0;
|
195 |
+
const next = parse.getLines((line, fieldLength) => {
|
196 |
+
++lineNum;
|
197 |
+
expect(decoder.decode(line)).toEqual(lineNum === 1 ? 'id: abc' : 'data: def');
|
198 |
+
expect(fieldLength).toEqual(lineNum === 1 ? 2 : 4);
|
199 |
+
});
|
200 |
+
|
201 |
+
// act:
|
202 |
+
next(encoder.encode('id: abc\rdata: def\r'));
|
203 |
+
|
204 |
+
// assert:
|
205 |
+
expect(lineNum).toBe(2);
|
206 |
+
});
|
207 |
+
|
208 |
+
// With `\r\n` terminators the pair must count as ONE line break: exactly two
// lines are expected (no empty line in between), `id: abc` (fieldLength 2)
// followed by `data: def` (fieldLength 4).
it('single byte array with multiple lines separated by \\r\\n', () => {
  // arrange:
  let lineNum = 0;
  const next = parse.getLines((line, fieldLength) => {
    ++lineNum;
    // lineNum has already been incremented, so 1 means "first line".
    expect(decoder.decode(line)).toEqual(lineNum === 1 ? 'id: abc' : 'data: def');
    expect(fieldLength).toEqual(lineNum === 1 ? 2 : 4);
  });

  // act:
  next(encoder.encode('id: abc\r\ndata: def\r\n'));

  // assert:
  expect(lineNum).toBe(2);
});
|
223 |
+
});
|
224 |
+
|
225 |
+
// Tests for parse.getMessages(onId, onRetry, onMessage).
//
// The function under test returns a per-line handler that each test calls as
// `next(lineBytes, fieldLength)`. The tests pass fieldLength by hand: it equals
// the length of the field name (`retry` -> 5, `id` -> 2, `event` -> 5,
// `data` -> 4), 0 for comment lines that start with a colon, and -1 for the
// empty line that ends every message in these tests.
describe('getMessages', () => {
  // All four known fields followed by an empty line yield exactly one message
  // carrying those values; `retry` is expected as the number 42, not a string.
  it('happy path', () => {
    // arrange:
    let msgNum = 0;
    const next = parse.getMessages(id => {
      expect(id).toEqual('abc');
    }, retry => {
      expect(retry).toEqual(42);
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        retry: 42,
        id: 'abc',
        event: 'def',
        data: 'ghi',
      });
    });

    // act:
    next(encoder.encode('retry: 42'), 5);
    next(encoder.encode('id: abc'), 2);
    next(encoder.encode('event:def'), 5);
    next(encoder.encode('data:ghi'), 4);
    next(encoder.encode(''), -1);

    // assert:
    expect(msgNum).toBe(1);
  });

  // An unrecognised field (`foo`) must not show up in the message and must not
  // trigger the retry callback; the remaining fields keep their empty defaults.
  it('skip unknown fields', () => {
    // arrange:
    let msgNum = 0;
    const next = parse.getMessages(id => {
      expect(id).toEqual('abc');
    }, _retry => {
      fail('retry should not be called');
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        id: 'abc',
        data: '',
        event: '',
        retry: undefined,
      });
    });

    // act:
    next(encoder.encode('id: abc'), 2);
    next(encoder.encode('foo: null'), 3);
    next(encoder.encode(''), -1);

    // assert:
    expect(msgNum).toBe(1);
  });

  // A `retry` value that is not an integer is dropped: the retry callback must
  // not run and the message's `retry` stays undefined.
  it('ignore non-integer retry', () => {
    // arrange:
    let msgNum = 0;
    const next = parse.getMessages(_id => {
      fail('id should not be called');
    }, _retry => {
      fail('retry should not be called');
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        id: '',
        data: '',
        event: '',
        retry: undefined,
      });
    });

    // act:
    next(encoder.encode('retry: def'), 5);
    next(encoder.encode(''), -1);

    // assert:
    expect(msgNum).toBe(1);
  });

  // Comment lines (fieldLength 0) between real fields contribute nothing: one
  // message is still produced, with the id and event from the surrounding
  // lines. Note the expected event value keeps its trailing space ('foo ').
  it('skip comment-only messages', () => {
    // arrange:
    let msgNum = 0;
    const next = parse.getMessages(id => {
      expect(id).toEqual('123');
    }, _retry => {
      fail('retry should not be called');
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        retry: undefined,
        id: '123',
        event: 'foo ',
        data: '',
      });
    });

    // act:
    next(encoder.encode('id:123'), 2);
    next(encoder.encode(':'), 0);
    next(encoder.encode(': '), 0);
    next(encoder.encode('event: foo '), 5);
    next(encoder.encode(''), -1);

    // assert:
    expect(msgNum).toBe(1);
  });

  // Several `data` lines in one message are expected to be joined with '\n'.
  // A bare `data` line (no colon, no value) contributes an empty segment,
  // which is why the expected payload contains two consecutive newlines.
  it('should append data split across multiple lines', () => {
    // arrange:
    let msgNum = 0;
    const next = parse.getMessages(_id => {
      fail('id should not be called');
    }, _retry => {
      fail('retry should not be called');
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        data: 'YHOO\n+2\n\n10',
        id: '',
        event: '',
        retry: undefined,
      });
    });

    // act:
    next(encoder.encode('data:YHOO'), 4);
    next(encoder.encode('data: +2'), 4);
    next(encoder.encode('data'), 4);
    next(encoder.encode('data: 10'), 4);
    next(encoder.encode(''), -1);

    // assert:
    expect(msgNum).toBe(1);
  });

  // Every `id` line must invoke the id callback, and a bare `id` line resets
  // the id to the empty string: the callback is expected to see 'foo' and then
  // '', and the resulting message carries id ''.
  it('should reset id if sent multiple times', () => {
    // arrange:
    const expectedIds = ['foo', ''];
    let idsIdx = 0;
    let msgNum = 0;
    const next = parse.getMessages(id => {
      expect(id).toEqual(expectedIds[idsIdx]);
      ++idsIdx;
    }, _retry => {
      fail('retry should not be called');
    }, msg => {
      ++msgNum;
      expect(msg).toEqual({
        data: '',
        id: '',
        event: '',
        retry: undefined,
      });
    });

    // act:
    next(encoder.encode('id: foo'), 2);
    next(encoder.encode('id'), 2);
    next(encoder.encode(''), -1);

    // assert: both id lines reached the callback, and one message was emitted.
    expect(idsIdx).toBe(2);
    expect(msgNum).toBe(1);
  });
});
|
389 |
+
});
|