cocorooxinnn
commited on
Commit
•
30e0662
1
Parent(s):
406337d
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# README
|
2 |
+
|
3 |
+
This is a facebook/bart-large-mnli model finetuned on a dataset created from EU taxonomie. The model is in ONNX format.
|
4 |
+
|
5 |
+
The model has a single neuron output for binary classification as to predict whether a keyword is semantically contained/relevant within a given topic. The topics can be from any domain.
|
6 |
+
|
7 |
+
| Tasks | Positive example | Negative example |
|
8 |
+
| ---------- | ------------------------------- | ---------------- |
|
9 |
+
| Topic | Fashion industry | Healthcare |
|
10 |
+
| Keyword | Pollution of textile production | Football athlete |
|
11 |
+
| Prediction | Relevant | Irrelevant |
|
12 |
+
|
13 |
+
How to load the model locally and do inference:
|
14 |
+
|
15 |
+
```python
|
16 |
+
from optimum.onnxruntime import ORTModelForSequenceClassification
|
17 |
+
from transformers import BartTokenizerFast, BartConfig
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
config = BartConfig(num_labels = 1)
|
21 |
+
model_checkpoint = "cocorooxinnn/optim_finetuned_bart_inclusion"
|
22 |
+
ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint,
|
23 |
+
config = config, provider="CUDAExecutionProvider",)
|
24 |
+
tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large-mnli")
|
25 |
+
|
26 |
+
assert str(device) == "cuda", "Error: CUDA device not available. The inference is optimized to run on a GPU. Please run the program on a GPU."
|
27 |
+
m = nn.Sigmoid
|
28 |
+
#Simple NLI template for premise, hypothesis
|
29 |
+
template = "They are talking about "
|
30 |
+
premise, hypothesis = (template + "Fashion industry", template + "Pollution of textile production ")
|
31 |
+
x = tokenizer(premise, hypothesis,return_tensors='pt').to(ort_model.device)
|
32 |
+
prediction = m(model(**x).logits).squeeze() > 0.5 # batch_size, 1
|
33 |
+
# True
|
34 |
+
```
|
35 |
+
|
36 |
+
## Training procedure
|
37 |
+
|
38 |
+
### Training hyperparameters
|
39 |
+
|
40 |
+
The following hyperparameters were used during training:
|
41 |
+
|
42 |
+
- learning_rate: 2e-4
|
43 |
+
- train_batch_size: 64
|
44 |
+
- eval_batch_size: 64
|
45 |
+
- seed: 42
|
46 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-06
|
47 |
+
- lr_scheduler_type: linear
|
48 |
+
- lr_scheduler_warmup_steps: 20
|
49 |
+
- num_epochs: 2
|
50 |
+
|
51 |
+
### Training results (F-score is calculated on validation set)
|
52 |
+
|
53 |
+
| Step | Training Loss | Validation Loss | F1 | Precision | Recall | Accuracy |
|
54 |
+
| ---- | ------------- | --------------- | -------- | --------- | -------- | -------- |
|
55 |
+
| 50 | 0.441800 | 0.285273 | 0.882523 | 0.892977 | 0.872311 | 0.883930 |
|
56 |
+
| 100 | 0.328400 | 0.259236 | 0.892380 | 0.903460 | 0.881568 | 0.893727 |
|
57 |
+
| 150 | 0.287400 | 0.257763 | 0.908518 | 0.877091 | 0.942282 | 0.905157 |
|
58 |
+
| 200 | 0.286300 | 0.243944 | 0.909042 | 0.898643 | 0.919684 | 0.908015 |
|
59 |
+
| 250 | 0.275500 | 0.239769 | 0.894515 | 0.925225 | 0.865777 | 0.897945 |
|
60 |
+
| 300 | 0.271800 | 0.222483 | 0.912901 | 0.912653 | 0.913150 | 0.912913 |
|
61 |
+
| 350 | 0.245500 | 0.221774 | 0.916039 | 0.905560 | 0.926763 | 0.915090 |
|
62 |
+
| 400 | 0.250700 | 0.219120 | 0.912324 | 0.924044 | 0.900898 | 0.913458 |
|
63 |
+
| 450 | 0.249000 | 0.211039 | 0.914482 | 0.922204 | 0.906888 | 0.915227 |
|
64 |
+
| 500 | 0.241500 | 0.207655 | 0.927005 | 0.910691 | 0.943915 | 0.925704 |
|
65 |
+
| 550 | 0.249900 | 0.197901 | 0.924230 | 0.925239 | 0.923224 | 0.924343 |
|
66 |
+
| 600 | 0.236000 | 0.202164 | 0.921937 | 0.929204 | 0.914784 | 0.922574 |
|
67 |
+
| 650 | 0.196100 | 0.192816 | 0.931687 | 0.918740 | 0.945004 | 0.930739 |
|
68 |
+
| 700 | 0.214800 | 0.206045 | 0.930494 | 0.894499 | 0.969507 | 0.927609 |
|
69 |
+
| 750 | 0.223600 | 0.186433 | 0.931180 | 0.928533 | 0.933842 | 0.931011 |
|
70 |
+
| 800 | 0.223900 | 0.189542 | 0.933564 | 0.911757 | 0.956439 | 0.931964 |
|
71 |
+
| 850 | 0.197500 | 0.191664 | 0.930473 | 0.928204 | 0.932753 | 0.930331 |
|
72 |
+
| 900 | 0.194600 | 0.185483 | 0.930797 | 0.922460 | 0.939287 | 0.930195 |
|
73 |
+
| 950 | 0.190200 | 0.183808 | 0.934791 | 0.916100 | 0.954261 | 0.933460 |
|
74 |
+
| 1000 | 0.189700 | 0.181666 | 0.934212 | 0.923404 | 0.945276 | 0.933460 |
|
75 |
+
| 1050 | 0.199300 | 0.177857 | 0.933693 | 0.924473 | 0.943098 | 0.933052 |
|
76 |
+
|
77 |
+
TrainOutput(global_step=1072, training_loss=0.2457642003671447, metrics={'train_runtime': 3750.3603, 'train_samples_per_second': 18.289, 'train_steps_per_second': 0.286, 'total_flos': 7425156147297000.0, 'train_loss': 0.2457642003671447, 'epoch': 2.0})
|
78 |
+
|
79 |
+
## Evaluation on testset
|
80 |
+
|
81 |
+
| Precision | Recall | F-score |
|
82 |
+
| --------- | ------ | ------- |
|
83 |
+
| 0.94 | 0.94 | 0.94 |
|
84 |
+
|
85 |
+
### Framework versions
|
86 |
+
|
87 |
+
- PEFT 0.10.0
|
88 |
+
- Transformers 4.41.2
|
89 |
+
- Pytorch 2.2.0+cu121
|
90 |
+
- Datasets 2.19.1
|
91 |
+
- Tokenizers 0.19.1
|