Commit eca067e
Parent(s): bf38c53

model

- 1_Pooling/config.json +7 -0
- README.md +77 -0
- __pycache__/handler.cpython-39.pyc +0 -0
- config.json +24 -0
- config_sentence_transformers.json +7 -0
- handler.py +27 -0
- model_head.pkl +3 -0
- modules.json +20 -0
- pytorch_model.bin +3 -0
- requirements.txt +0 -0
- sentence_bert_config.json +4 -0
- special_tokens_map.json +15 -0
- tokenizer.json +0 -0
- tokenizer_config.json +16 -0
- vocab.txt +0 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,7 @@
+{
+    "word_embedding_dimension": 768,
+    "pooling_mode_cls_token": false,
+    "pooling_mode_mean_tokens": true,
+    "pooling_mode_max_tokens": false,
+    "pooling_mode_mean_sqrt_len_tokens": false
+}
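For reference, `pooling_mode_mean_tokens: true` means sentence embeddings are produced by averaging the 768-dimensional token embeddings while masking out padding. A minimal sketch of that operation (illustrative only, not a file in this commit):

```python
# Sketch of what sentence_transformers.models.Pooling does when
# pooling_mode_mean_tokens is true: an attention-mask-aware mean
# over the token embeddings.
import torch

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (batch, seq_len, 768); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).type_as(token_embeddings)
    summed = (token_embeddings * mask).sum(dim=1)   # padding contributes zero
    counts = mask.sum(dim=1).clamp(min=1e-9)        # number of real tokens
    return summed / counts                          # (batch, 768)
```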
README.md
ADDED
@@ -0,0 +1,77 @@
+---
+license: mit
+tags:
+- setfit
+- endpoints-template
+- text-classification
+---
+
+# SetFit AG News
+
+This is a [SetFit](https://github.com/huggingface/setfit/tree/main) classifier fine-tuned on the [AG News](https://huggingface.co/datasets/ag_news) dataset.
+The model was created following the [Outperform OpenAI GPT-3 with SetFit for text-classification](https://www.philschmid.de/getting-started-setfit) blog post by [Philipp Schmid](https://www.linkedin.com/in/philipp-schmid-a6a2bb196/).
+
+The model achieves an accuracy of 0.87 on the test set and was trained with only `32` total examples (8 per class).
+
+```bash
+***** Running evaluation *****
+model used: sentence-transformers/all-mpnet-base-v2
+train dataset: 32 samples
+accuracy: 0.8731578947368421
+```
+
+#### What is SetFit?
+
+[SetFit](https://arxiv.org/abs/2209.11055) is a new approach for creating highly accurate text-classification models with limited labeled data. SetFit outperforms GPT-3 on 7 out of 11 tasks while being 1600x smaller.
+Check out the blog post to learn more: [Outperform OpenAI GPT-3 with SetFit for text-classification](https://www.philschmid.de/getting-started-setfit)
+
+# Inference Endpoints
+
+The model repository also implements a generic custom `handler.py` as an example of how to use `SetFit` models with [inference-endpoints](https://hf.co/inference-endpoints).
+
+Code: https://huggingface.co/philschmid/setfit-ag-news-endpoint/blob/main/handler.py
+
+![result](res.png)
+
+## Send requests with Python
+
+We are going to use `requests` to send our requests (make sure you have it installed: `pip install requests`).
+
+```python
+import requests as r
+
+ENDPOINT_URL = ""  # url of your endpoint
+HF_TOKEN = ""  # your Hugging Face access token
+
+# payload sample
+regular_payload = {"inputs": "The New Customers Are In Town Today's customers are increasingly demanding, in Asia as elsewhere in the world. Henry Astorga describes the complex reality faced by today's marketers, which includes much higher expectations than we have been used to. Today's customers want performance, and they want it now!"}
+
+# HTTP headers for authorization
+headers = {
+    "Authorization": f"Bearer {HF_TOKEN}",
+    "Content-Type": "application/json"
+}
+
+# send request
+response = r.post(ENDPOINT_URL, headers=headers, json=regular_payload)
+classified = response.json()
+
+print(classified)
+```
+
+**curl example**
+
+```bash
+curl https://ak7gduay2ypyr9vp.us-east-1.aws.endpoints.huggingface.cloud \
+  -X POST \
+  -d '{"inputs": "Stocks rallied after the central bank left rates unchanged."}' \
+  -H "Authorization: Bearer XXX" \
+  -H "Content-Type: application/json"
+```
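For context, the few-shot training recipe from the linked blog post looks roughly like the sketch below. It assumes the setfit 0.x API (`SetFitTrainer`, `sample_dataset`) that was current when this commit was made; it is not a file in this repository:

```python
# Illustrative sketch of the blog post's few-shot recipe, not part of this commit.
# Assumes setfit 0.x: pip install setfit datasets
from datasets import load_dataset
from setfit import SetFitModel, SetFitTrainer, sample_dataset

dataset = load_dataset("ag_news")
# 8 labeled examples per class -> 32 training samples in total
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)

model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=dataset["test"],
)
trainer.train()
print(trainer.evaluate())  # {"accuracy": ...}
```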
__pycache__/handler.cpython-39.pyc
ADDED
Binary file (1.37 kB).
config.json
ADDED
@@ -0,0 +1,24 @@
+{
+  "_name_or_path": "/home/ubuntu/.cache/torch/sentence_transformers/sentence-transformers_all-mpnet-base-v2/",
+  "architectures": [
+    "MPNetModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "bos_token_id": 0,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 514,
+  "model_type": "mpnet",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "relative_attention_num_buckets": 32,
+  "torch_dtype": "float32",
+  "transformers_version": "4.23.1",
+  "vocab_size": 30527
+}
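A quick way to confirm what this backbone config describes (a sketch against the published repo id from the README; not part of the commit):

```python
# Sketch: the root config.json describes the MPNet encoder; its hidden size
# matches word_embedding_dimension in 1_Pooling/config.json.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("philschmid/setfit-ag-news-endpoint")
assert config.model_type == "mpnet"
assert config.hidden_size == 768
assert config.num_hidden_layers == 12
```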
config_sentence_transformers.json
ADDED
@@ -0,0 +1,7 @@
+{
+  "__version__": {
+    "sentence_transformers": "2.0.0",
+    "transformers": "4.6.1",
+    "pytorch": "1.8.1"
+  }
+}
handler.py
ADDED
@@ -0,0 +1,27 @@
+from typing import Dict, List, Any
+from setfit import SetFitModel
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        # load model
+        self.model = SetFitModel.from_pretrained(path)
+        # ag_news id to label mapping
+        self.id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        data args:
+            inputs (:obj: `str`)
+        Return:
+            A :obj:`list` | `dict`: will be serialized and returned
+        """
+        # get inputs
+        inputs = data.pop("inputs", data)
+        if isinstance(inputs, str):
+            inputs = [inputs]
+
+        # run normal prediction
+        scores = self.model.predict_proba(inputs)[0]
+
+        return [{"label": self.id2label[i], "score": score.item()} for i, score in enumerate(scores)]
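A simple way to sanity-check a custom handler like this before deploying is to instantiate it locally. A minimal smoke test (a sketch under stated assumptions, not a file in this commit):

```python
# Illustrative local smoke test, not part of the repo. Assumes it runs from
# inside a local clone of this repository, so handler.py, pytorch_model.bin
# and model_head.pkl are all in the current directory.
from handler import EndpointHandler

my_handler = EndpointHandler(path=".")

payload = {"inputs": "Stocks rallied after the central bank left rates unchanged."}
predictions = my_handler(payload)
print(predictions)  # one {"label": ..., "score": ...} dict per AG News class
```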
model_head.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70d5c2cb3d62b4ed52b0fceee866879cb89f76a6fb9541c1870786037eb150aa
+size 25399
modules.json
ADDED
@@ -0,0 +1,20 @@
+[
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "sentence_transformers.models.Transformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Pooling",
+    "type": "sentence_transformers.models.Pooling"
+  },
+  {
+    "idx": 2,
+    "name": "2",
+    "path": "2_Normalize",
+    "type": "sentence_transformers.models.Normalize"
+  }
+]
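`modules.json` wires the repository into a sentence-transformers pipeline in `idx` order: the MPNet Transformer at the repo root, mean pooling from `1_Pooling`, then L2 normalization from `2_Normalize`. Loading the embedding body directly works like this sketch (illustrative, not part of the commit):

```python
# Sketch: SentenceTransformer reads modules.json and applies
# Transformer -> Pooling -> Normalize in order.
from sentence_transformers import SentenceTransformer

body = SentenceTransformer("philschmid/setfit-ag-news-endpoint")
embeddings = body.encode(["Today's customers want performance, and they want it now!"])
print(embeddings.shape)  # (1, 768); rows are unit-length due to the Normalize module
```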
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:393137b7f36cac72e781dc14923ecf49a05927805b148d519dd4bab662ed5b4d
+size 438014769
requirements.txt
ADDED
File without changes
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
+{
+  "max_seq_length": 384,
+  "do_lower_case": false
+}
special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "eos_token": "</s>",
+  "mask_token": {
+    "content": "<mask>",
+    "lstrip": true,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "unk_token": "[UNK]"
+}
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,16 @@
+{
+  "bos_token": "<s>",
+  "cls_token": "<s>",
+  "do_lower_case": true,
+  "eos_token": "</s>",
+  "mask_token": "<mask>",
+  "model_max_length": 512,
+  "name_or_path": "/home/ubuntu/.cache/torch/sentence_transformers/sentence-transformers_all-mpnet-base-v2/",
+  "pad_token": "<pad>",
+  "sep_token": "</s>",
+  "special_tokens_map_file": null,
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "MPNetTokenizer",
+  "unk_token": "[UNK]"
+}
vocab.txt
ADDED
The diff for this file is too large to render.