# README

This is a facebook/bart-large-mnli model fine-tuned on a dataset created from the EU taxonomy. The model is in ONNX format.

The model has a single output neuron for binary classification: it predicts whether a keyword is semantically contained in, or relevant to, a given topic. The topics can come from any domain.

|            | Positive example                | Negative example |
| ---------- | ------------------------------- | ---------------- |
| Topic      | Fashion industry                | Healthcare       |
| Keyword    | Pollution of textile production | Football athlete |
| Prediction | Relevant                        | Irrelevant       |

How to load the model locally and run inference:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import BartTokenizerFast, BartConfig
from torch import nn

# Single-label head: one output neuron for binary relevance classification
config = BartConfig(num_labels=1)
model_checkpoint = "cocorooxinnn/optim_finetuned_bart_inclusion"
ort_model = ORTModelForSequenceClassification.from_pretrained(
    model_checkpoint, config=config, provider="CUDAExecutionProvider",
)
tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large-mnli")

# Inference is optimized to run on a GPU
assert ort_model.device.type == "cuda", "Error: CUDA device not available. Inference is optimized to run on a GPU; please run the program on a GPU."

m = nn.Sigmoid()

# Simple NLI template for premise (topic) and hypothesis (keyword)
template = "They are talking about "
premise, hypothesis = template + "Fashion industry", template + "Pollution of textile production"
x = tokenizer(premise, hypothesis, return_tensors="pt").to(ort_model.device)
prediction = m(ort_model(**x).logits).squeeze() > 0.5  # logits shape: (batch_size, 1)
# prediction -> True (the keyword is relevant to the topic)
```

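For scoring several keywords against a single topic in one call, here is a minimal batched sketch that reuses the objects loaded above (`ort_model`, `tokenizer`, `template`, `m`); the keyword list is purely illustrative:

```python
# Minimal batched-scoring sketch; assumes ort_model, tokenizer, template and m
# are defined as in the snippet above. The keywords below are illustrative only.
topic = "Fashion industry"
keywords = ["Pollution of textile production", "Football athlete", "Cotton farming"]

premises = [template + topic] * len(keywords)
hypotheses = [template + kw for kw in keywords]

batch = tokenizer(premises, hypotheses, padding=True, truncation=True,
                  return_tensors="pt").to(ort_model.device)
scores = m(ort_model(**batch).logits).squeeze(-1)  # shape: (len(keywords),)
for kw, score in zip(keywords, scores.tolist()):
    print(f"{kw}: {score:.3f} -> {'Relevant' if score > 0.5 else 'Irrelevant'}")
```
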
## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (see the sketch after the list):

- learning_rate: 2e-4
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-06
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 20
- num_epochs: 2

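A minimal sketch (not from the original training script) of how these values map onto `transformers.TrainingArguments`; the `output_dir` is a placeholder:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="bart_inclusion_finetune",  # hypothetical output directory
    learning_rate=2e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-6,
    lr_scheduler_type="linear",
    warmup_steps=20,
    num_train_epochs=2,
)
```
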
### Training results (F-score is computed on the validation set)

| Step | Training Loss | Validation Loss | F1 | Precision | Recall | Accuracy |
| ---- | ------------- | --------------- | -------- | --------- | -------- | -------- |
| 50 | 0.441800 | 0.285273 | 0.882523 | 0.892977 | 0.872311 | 0.883930 |
| 100 | 0.328400 | 0.259236 | 0.892380 | 0.903460 | 0.881568 | 0.893727 |
| 150 | 0.287400 | 0.257763 | 0.908518 | 0.877091 | 0.942282 | 0.905157 |
| 200 | 0.286300 | 0.243944 | 0.909042 | 0.898643 | 0.919684 | 0.908015 |
| 250 | 0.275500 | 0.239769 | 0.894515 | 0.925225 | 0.865777 | 0.897945 |
| 300 | 0.271800 | 0.222483 | 0.912901 | 0.912653 | 0.913150 | 0.912913 |
| 350 | 0.245500 | 0.221774 | 0.916039 | 0.905560 | 0.926763 | 0.915090 |
| 400 | 0.250700 | 0.219120 | 0.912324 | 0.924044 | 0.900898 | 0.913458 |
| 450 | 0.249000 | 0.211039 | 0.914482 | 0.922204 | 0.906888 | 0.915227 |
| 500 | 0.241500 | 0.207655 | 0.927005 | 0.910691 | 0.943915 | 0.925704 |
| 550 | 0.249900 | 0.197901 | 0.924230 | 0.925239 | 0.923224 | 0.924343 |
| 600 | 0.236000 | 0.202164 | 0.921937 | 0.929204 | 0.914784 | 0.922574 |
| 650 | 0.196100 | 0.192816 | 0.931687 | 0.918740 | 0.945004 | 0.930739 |
| 700 | 0.214800 | 0.206045 | 0.930494 | 0.894499 | 0.969507 | 0.927609 |
| 750 | 0.223600 | 0.186433 | 0.931180 | 0.928533 | 0.933842 | 0.931011 |
| 800 | 0.223900 | 0.189542 | 0.933564 | 0.911757 | 0.956439 | 0.931964 |
| 850 | 0.197500 | 0.191664 | 0.930473 | 0.928204 | 0.932753 | 0.930331 |
| 900 | 0.194600 | 0.185483 | 0.930797 | 0.922460 | 0.939287 | 0.930195 |
| 950 | 0.190200 | 0.183808 | 0.934791 | 0.916100 | 0.954261 | 0.933460 |
| 1000 | 0.189700 | 0.181666 | 0.934212 | 0.923404 | 0.945276 | 0.933460 |
| 1050 | 0.199300 | 0.177857 | 0.933693 | 0.924473 | 0.943098 | 0.933052 |

Final trainer output: `TrainOutput(global_step=1072, training_loss=0.2457642003671447, metrics={'train_runtime': 3750.3603, 'train_samples_per_second': 18.289, 'train_steps_per_second': 0.286, 'total_flos': 7425156147297000.0, 'train_loss': 0.2457642003671447, 'epoch': 2.0})`

## Evaluation on the test set

| Precision | Recall | F-score |
| --------- | ------ | ------- |
| 0.94 | 0.94 | 0.94 |

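These test-set numbers correspond to a standard binary-classification report; a minimal sketch (not part of the original card), assuming `y_true` and `y_pred` are 0/1 label lists collected as in the inference snippet above:

```python
from sklearn.metrics import precision_recall_fscore_support

# Illustrative labels only; real values come from the held-out test set.
y_true = [1, 0, 1, 1]
y_pred = [1, 0, 1, 0]

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary"
)
print(f"Precision {precision:.2f}  Recall {recall:.2f}  F-score {f1:.2f}")
```
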
### Framework versions

- PEFT 0.10.0
- Transformers 4.41.2
- Pytorch 2.2.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1