Text Classification
Transformers
PyTorch
Arabic
English
distilbert
chemistry
biology
finance
legal
music
code
art
climate
medical
emotion
endpoints-template
Inference Endpoints
PetraAI commited on
Commit
3427c7d
1 Parent(s): 2c2e361

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -2,13 +2,11 @@
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.npy filter=lfs diff=lfs merge=lfs -text
@@ -22,10 +20,8 @@
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,15 +29,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- codellama-7b-instruct.Q2_K.gguf filter=lfs diff=lfs merge=lfs -text
37
- codellama-7b-instruct.Q3_K_S.gguf filter=lfs diff=lfs merge=lfs -text
38
- codellama-7b-instruct.Q3_K_M.gguf filter=lfs diff=lfs merge=lfs -text
39
- codellama-7b-instruct.Q3_K_L.gguf filter=lfs diff=lfs merge=lfs -text
40
- codellama-7b-instruct.Q4_K_S.gguf filter=lfs diff=lfs merge=lfs -text
41
- codellama-7b-instruct.Q4_K_M.gguf filter=lfs diff=lfs merge=lfs -text
42
- codellama-7b-instruct.Q5_K_S.gguf filter=lfs diff=lfs merge=lfs -text
43
- codellama-7b-instruct.Q5_K_M.gguf filter=lfs diff=lfs merge=lfs -text
44
- codellama-7b-instruct.Q6_K.gguf filter=lfs diff=lfs merge=lfs -text
45
- codellama-7b-instruct.Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
46
- codellama-7b-instruct.Q4_0.gguf filter=lfs diff=lfs merge=lfs -text
47
- codellama-7b-instruct.Q5_0.gguf filter=lfs diff=lfs merge=lfs -text
 
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
5
  *.ftz filter=lfs diff=lfs merge=lfs -text
6
  *.gz filter=lfs diff=lfs merge=lfs -text
7
  *.h5 filter=lfs diff=lfs merge=lfs -text
8
  *.joblib filter=lfs diff=lfs merge=lfs -text
9
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
10
  *.model filter=lfs diff=lfs merge=lfs -text
11
  *.msgpack filter=lfs diff=lfs merge=lfs -text
12
  *.npy filter=lfs diff=lfs merge=lfs -text
 
20
  *.pt filter=lfs diff=lfs merge=lfs -text
21
  *.pth filter=lfs diff=lfs merge=lfs -text
22
  *.rar filter=lfs diff=lfs merge=lfs -text
 
23
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
25
  *.tflite filter=lfs diff=lfs merge=lfs -text
26
  *.tgz filter=lfs diff=lfs merge=lfs -text
27
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,336 +1,16 @@
1
  ---
2
- license: apache-2.0
3
- datasets:
4
- - PetraAI/PetraAI
5
  language:
6
- - ar
7
  - en
8
- - ch
9
- - zh
10
- metrics:
11
- - accuracy
12
- - bertscore
13
- - bleu
14
- - chrf
15
- - code_eval
16
- - brier_score
17
  tags:
18
- - chemistry
19
- - biology
20
- - finance
21
- - legal
22
- - music
23
- - code
24
- - art
25
- - climate
26
- - medical
27
- - text-generation-inference
28
  ---
29
 
30
- ### Inference Speed
31
- > The result is generated using [this script](examples/benchmark/generation_speed.py), batch size of input is 1, decode strategy is beam search and enforce the model to generate 512 tokens, speed metric is tokens/s (the larger, the better).
32
- >
33
- > The quantized model is loaded using the setup that can gain the fastest inference speed.
34
-
35
- | model | GPU | num_beams | fp16 | gptq-int4 |
36
- |---------------|---------------|-----------|-------|-----------|
37
- | llama-7b | 1xA100-40G | 1 | 18.87 | 25.53 |
38
- | llama-7b | 1xA100-40G | 4 | 68.79 | 91.30 |
39
- | moss-moon 16b | 1xA100-40G | 1 | 12.48 | 15.25 |
40
- | moss-moon 16b | 1xA100-40G | 4 | OOM | 42.67 |
41
- | moss-moon 16b | 2xA100-40G | 1 | 06.83 | 06.78 |
42
- | moss-moon 16b | 2xA100-40G | 4 | 13.10 | 10.80 |
43
- | gpt-j 6b | 1xRTX3060-12G | 1 | OOM | 29.55 |
44
- | gpt-j 6b | 1xRTX3060-12G | 4 | OOM | 47.36 |
45
-
46
-
47
- ### Perplexity
48
- For perplexity comparison, you can turn to [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#result) and [here](https://github.com/qwopqwop200/GPTQ-for-LLaMa#gptq-vs-bitsandbytes)
49
-
50
- ## Installation
51
-
52
- ### Quick Installation
53
- You can install the latest stable release of AutoGPTQ from pip with pre-built wheels compatible with PyTorch 2.0.1:
54
-
55
- * For CUDA 11.7: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/`
56
- * For CUDA 11.8: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/`
57
- * For RoCm 5.4.2: `pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/rocm542/`
58
-
59
- **Warning:** These wheels are not expected to work on PyTorch nightly. Please install AutoGPTQ from source when using PyTorch nightly.
60
-
61
- #### disable cuda extensions
62
- By default, cuda extensions will be installed when `torch` and `cuda` is already installed in your machine, if you don't want to use them, using:
63
- ```shell
64
- BUILD_CUDA_EXT=0 pip install auto-gptq
65
- ```
66
- And to make sure `autogptq_cuda` is not ever in your virtual environment, run:
67
- ```shell
68
- pip uninstall autogptq_cuda -y
69
- ```
70
-
71
- #### to support triton speedup
72
- To integrate with `triton`, using:
73
- > warning: currently triton only supports linux; 3-bit quantization is not supported when using triton
74
-
75
- ```shell
76
- pip install auto-gptq[triton]
77
- ```
78
-
79
- ### Install from source
80
- <details>
81
- <summary>click to see details</summary>
82
-
83
- Clone the source code:
84
- ```shell
85
- git clone https://github.com/PanQiWei/AutoGPTQ.git && cd AutoGPTQ
86
- ```
87
- Then, install from source:
88
- ```shell
89
- pip install .
90
- ```
91
- Like quick installation, you can also set `BUILD_CUDA_EXT=0` to disable pytorch extension building.
92
-
93
- Use `.[triton]` if you want to integrate with triton and it's available on your operating system.
94
-
95
- To install from source for AMD GPUs supporting RoCm, please specify the `ROCM_VERSION` environment variable. The compilation can be speeded up by specifying the `PYTORCH_ROCM_ARCH` variable ([reference](https://github.com/pytorch/pytorch/blob/7b73b1e8a73a1777ebe8d2cd4487eb13da55b3ba/setup.py#L132)), for example `gfx90a` for MI200 series devices. Example:
96
-
97
- ```
98
- ROCM_VERSION=5.6 pip install .
99
- ```
100
-
101
- For RoCm systems, the packages `rocsparse-dev`, `hipsparse-dev`, `rocthrust-dev`, `rocblas-dev` and `hipblas-dev` are required to build.
102
-
103
- </details>
104
-
105
- ## Quick Tour
106
-
107
- ### Quantization and Inference
108
- > warning: this is just a showcase of the usage of basic apis in AutoGPTQ, which uses only one sample to quantize a much small model, quality of quantized model using such little samples may not good.
109
-
110
- Below is an example for the simplest use of `auto_gptq` to quantize a model and inference after quantization:
111
- ```python
112
- from transformers import AutoTokenizer, TextGenerationPipeline
113
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
114
- import logging
115
-
116
- logging.basicConfig(
117
- format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
118
- )
119
-
120
- pretrained_model_dir = "facebook/opt-125m"
121
- quantized_model_dir = "opt-125m-4bit"
122
-
123
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
124
- examples = [
125
- tokenizer(
126
- "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
127
- )
128
- ]
129
-
130
- quantize_config = BaseQuantizeConfig(
131
- bits=4, # quantize model to 4-bit
132
- group_size=128, # it is recommended to set the value to 128
133
- desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
134
- )
135
-
136
- # load un-quantized model, by default, the model will always be loaded into CPU memory
137
- model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
138
-
139
- # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
140
- model.quantize(examples)
141
-
142
- # save quantized model
143
- model.save_quantized(quantized_model_dir)
144
-
145
- # save quantized model using safetensors
146
- model.save_quantized(quantized_model_dir, use_safetensors=True)
147
-
148
- # push quantized model to Hugging Face Hub.
149
- # to use use_auth_token=True, Login first via huggingface-cli login.
150
- # or pass explcit token with: use_auth_token="hf_xxxxxxx"
151
- # (uncomment the following three lines to enable this feature)
152
- # repo_id = f"YourUserName/{quantized_model_dir}"
153
- # commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
154
- # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)
155
-
156
- # alternatively you can save and push at the same time
157
- # (uncomment the following three lines to enable this feature)
158
- # repo_id = f"YourUserName/{quantized_model_dir}"
159
- # commit_message = f"AutoGPTQ model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
160
- # model.push_to_hub(repo_id, save_dir=quantized_model_dir, use_safetensors=True, commit_message=commit_message, use_auth_token=True)
161
-
162
- # load quantized model to the first GPU
163
- model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")
164
-
165
- # download quantized model from Hugging Face Hub and load to the first GPU
166
- # model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False)
167
-
168
- # inference with model.generate
169
- print(tokenizer.decode(model.generate(**tokenizer("auto_gptq is", return_tensors="pt").to(model.device))[0]))
170
-
171
- # or you can also use pipeline
172
- pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
173
- print(pipeline("auto-gptq is")[0]["generated_text"])
174
- ```
175
-
176
- For more advanced features of model quantization, please reference to [this script](examples/quantization/quant_with_alpaca.py)
177
-
178
- ### Customize Model
179
- <details>
180
-
181
- <summary>Below is an example to extend `auto_gptq` to support `OPT` model, as you will see, it's very easy:</summary>
182
-
183
- ```python
184
- from auto_gptq.modeling import BaseGPTQForCausalLM
185
-
186
-
187
- class OPTGPTQForCausalLM(BaseGPTQForCausalLM):
188
- # chained attribute name of transformer layer block
189
- layers_block_name = "model.decoder.layers"
190
- # chained attribute names of other nn modules that in the same level as the transformer layer block
191
- outside_layer_modules = [
192
- "model.decoder.embed_tokens", "model.decoder.embed_positions", "model.decoder.project_out",
193
- "model.decoder.project_in", "model.decoder.final_layer_norm"
194
- ]
195
- # chained attribute names of linear layers in transformer layer module
196
- # normally, there are four sub lists, for each one the modules in it can be seen as one operation,
197
- # and the order should be the order when they are truly executed, in this case (and usually in most cases),
198
- # they are: attention q_k_v projection, attention output projection, MLP project input, MLP project output
199
- inside_layer_modules = [
200
- ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
201
- ["self_attn.out_proj"],
202
- ["fc1"],
203
- ["fc2"]
204
- ]
205
- ```
206
- After this, you can use `OPTGPTQForCausalLM.from_pretrained` and other methods as shown in Basic.
207
-
208
- </details>
209
-
210
- ### Evaluation on Downstream Tasks
211
- You can use tasks defined in `auto_gptq.eval_tasks` to evaluate model's performance on specific down-stream task before and after quantization.
212
-
213
- The predefined tasks support all causal-language-models implemented in [🤗 transformers](https://github.com/huggingface/transformers) and in this project.
214
-
215
- <details>
216
-
217
- <summary>Below is an example to evaluate `EleutherAI/gpt-j-6b` on sequence-classification task using `cardiffnlp/tweet_sentiment_multilingual` dataset:</summary>
218
-
219
- ```python
220
- from functools import partial
221
-
222
- import datasets
223
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
224
-
225
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
226
- from auto_gptq.eval_tasks import SequenceClassificationTask
227
-
228
-
229
- MODEL = "EleutherAI/gpt-j-6b"
230
- DATASET = "cardiffnlp/tweet_sentiment_multilingual"
231
- TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
232
- ID2LABEL = {
233
- 0: "negative",
234
- 1: "neutral",
235
- 2: "positive"
236
- }
237
- LABELS = list(ID2LABEL.values())
238
-
239
-
240
- def ds_refactor_fn(samples):
241
- text_data = samples["text"]
242
- label_data = samples["label"]
243
-
244
- new_samples = {"prompt": [], "label": []}
245
- for text, label in zip(text_data, label_data):
246
- prompt = TEMPLATE.format(labels=LABELS, text=text)
247
- new_samples["prompt"].append(prompt)
248
- new_samples["label"].append(ID2LABEL[label])
249
-
250
- return new_samples
251
-
252
-
253
- # model = AutoModelForCausalLM.from_pretrained(MODEL).eval().half().to("cuda:0")
254
- model = AutoGPTQForCausalLM.from_pretrained(MODEL, BaseQuantizeConfig())
255
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
256
-
257
- task = SequenceClassificationTask(
258
- model=model,
259
- tokenizer=tokenizer,
260
- classes=LABELS,
261
- data_name_or_path=DATASET,
262
- prompt_col_name="prompt",
263
- label_col_name="label",
264
- **{
265
- "num_samples": 1000, # how many samples will be sampled to evaluation
266
- "sample_max_len": 1024, # max tokens for each sample
267
- "block_max_len": 2048, # max tokens for each data block
268
- # function to load dataset, one must only accept data_name_or_path as input
269
- # and return datasets.Dataset
270
- "load_fn": partial(datasets.load_dataset, name="english"),
271
- # function to preprocess dataset, which is used for datasets.Dataset.map,
272
- # must return Dict[str, list] with only two keys: [prompt_col_name, label_col_name]
273
- "preprocess_fn": ds_refactor_fn,
274
- # truncate label when sample's length exceed sample_max_len
275
- "truncate_prompt": False
276
- }
277
- )
278
-
279
- # note that max_new_tokens will be automatically specified internally based on given classes
280
- print(task.run())
281
-
282
- # self-consistency
283
- print(
284
- task.run(
285
- generation_config=GenerationConfig(
286
- num_beams=3,
287
- num_return_sequences=3,
288
- do_sample=True
289
- )
290
- )
291
- )
292
- ```
293
-
294
- </details>
295
-
296
- ## Learn More
297
- [tutorials](docs/tutorial) provide step-by-step guidance to integrate `auto_gptq` with your own project and some best practice principles.
298
-
299
- [examples](examples/README.md) provide plenty of example scripts to use `auto_gptq` in different ways.
300
-
301
- ## Supported Models
302
-
303
- > you can use `model.config.model_type` to compare with the table below to check whether the model you use is supported by `auto_gptq`.
304
- >
305
- > for example, model_type of `WizardLM`, `vicuna` and `gpt4all` are all `llama`, hence they are all supported by `auto_gptq`.
306
-
307
- | model type | quantization | inference | peft-lora | peft-ada-lora | peft-adaption_prompt |
308
- |------------------------------------|--------------|-----------|-----------|---------------|-------------------------------------------------------------------------------------------------|
309
- | bloom | ✅ | ✅ | ✅ | ✅ | |
310
- | gpt2 | ✅ | ✅ | ✅ | ✅ | |
311
- | gpt_neox | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
312
- | gptj | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
313
- | llama | ✅ | ✅ | ✅ | ✅ | ✅ |
314
- | moss | ✅ | ✅ | ✅ | ✅ | ✅[requires this peft branch](https://github.com/PanQiWei/peft/tree/multi_modal_adaption_prompt) |
315
- | opt | ✅ | ✅ | ✅ | ✅ | |
316
- | gpt_bigcode | ✅ | ✅ | ✅ | ✅ | |
317
- | codegen | ✅ | ✅ | ✅ | ✅ | |
318
- | falcon(RefinedWebModel/RefinedWeb) | ✅ | ✅ | ✅ | ✅ | |
319
-
320
- ## Supported Evaluation Tasks
321
- Currently, `auto_gptq` supports: `LanguageModelingTask`, `SequenceClassificationTask` and `TextSummarizationTask`; more Tasks will come soon!
322
-
323
- ## Running tests
324
-
325
- Tests can be run with:
326
-
327
- ```
328
- pytest tests/ -s
329
- ```
330
-
331
- ## Acknowledgement
332
- - Specially thanks **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh** for proposing **GPTQ** algorithm and open source the [code](https://github.com/IST-DASLab/gptq).
333
- - Specially thanks **qwopqwop200**, for code in this project that relevant to quantization are mainly referenced from [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda).
334
-
335
 
336
- [![Star History Chart](https://api.star-history.com/svg?repos=PanQiwei/AutoGPTQ&type=Date)](https://star-history.com/#PanQiWei/AutoGPTQ&Date)
 
1
  ---
 
 
 
2
  language:
 
3
  - en
 
 
 
 
 
 
 
 
 
4
  tags:
5
+ - text-classification
6
+ - emotion
7
+ - endpoints-template
8
+ license: apache-2.0
9
+ datasets:
10
+ - emotion
11
+ metrics:
12
+ - Accuracy, F1 Score
 
 
13
  ---
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # Fork of [bhadresh-savani/distilbert-base-uncased-emotion](https://huggingface.co/bhadresh-savani/distilbert-base-uncased-emotion)
config.json CHANGED
@@ -1,3 +1,39 @@
1
  {
2
- "model_type": "llama"
3
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  {
2
+ "_name_or_path": "./",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "DistilBertForSequenceClassification"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "hidden_dim": 3072,
11
+ "id2label": {
12
+ "0": "sadness",
13
+ "1": "joy",
14
+ "2": "love",
15
+ "3": "anger",
16
+ "4": "fear",
17
+ "5": "surprise"
18
+ },
19
+ "initializer_range": 0.02,
20
+ "label2id": {
21
+ "anger": 3,
22
+ "fear": 4,
23
+ "joy": 1,
24
+ "love": 2,
25
+ "sadness": 0,
26
+ "surprise": 5
27
+ },
28
+ "max_position_embeddings": 512,
29
+ "model_type": "distilbert",
30
+ "n_heads": 12,
31
+ "n_layers": 6,
32
+ "pad_token_id": 0,
33
+ "qa_dropout": 0.1,
34
+ "seq_classif_dropout": 0.2,
35
+ "sinusoidal_pos_embds": false,
36
+ "tie_weights_": true,
37
+ "transformers_version": "4.11.0.dev0",
38
+ "vocab_size": 30522
39
+ }
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import pipeline
3
+ import holidays
4
+
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path=""):
8
+ self.pipeline = pipeline("text-classification", model=path)
9
+ self.holidays = holidays.US()
10
+
11
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
+ """
13
+ data args:
14
+ inputs (:obj: `str`)
15
+ date (:obj: `str`)
16
+ Return:
17
+ A :obj:`list` | `dict`: will be serialized and returned
18
+ """
19
+ # get inputs
20
+ inputs = data.pop("inputs", data)
21
+ # get additional date field
22
+ date = data.pop("date", None)
23
+
24
+ # check if date exists and if it is a holiday
25
+ if date is not None and date in self.holidays:
26
+ return [{"label": "happy", "score": 1}]
27
+
28
+ # run normal prediction
29
+ prediction = self.pipeline(inputs)
30
+ return prediction
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5aa7398d830fcc94f95af88d7cc3013813668cfc58a07d75a8116cfd8af75c4d
3
+ size 267875479
requirements.txt CHANGED
@@ -1,6 +1 @@
1
- pandas
2
- ninja
3
- fastparquet
4
- torch>=2.0.1
5
- safetensors>=0.3.2
6
- sentencepiece>=0.1.97
 
1
+ holidaysholidays
 
 
 
 
 
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "distilbert-base-uncased"}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff