# README

This is a facebook/bart-large-mnli model fine-tuned on a dataset created from the EU taxonomy. The model is stored in ONNX format. It has a single output neuron for binary classification: it predicts whether a keyword is semantically contained in (i.e. relevant to) a given topic. The topics can come from any domain.

|            | Positive example                | Negative example |
| ---------- | ------------------------------- | ---------------- |
| Topic      | Fashion industry                | Healthcare       |
| Keyword    | Pollution of textile production | Football athlete |
| Prediction | Relevant                        | Irrelevant       |

How to load the model locally and run inference:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import BartTokenizerFast, BartConfig
from torch import nn

config = BartConfig(num_labels=1)
model_checkpoint = "cocorooxinnn/optim_finetuned_bart_inclusion"
ort_model = ORTModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, provider="CUDAExecutionProvider")
tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large-mnli")

device = ort_model.device
assert device.type == "cuda", "Error: CUDA device not available. Inference is optimized to run on a GPU. Please run the program on a GPU."

m = nn.Sigmoid()

# Simple NLI template: the premise carries the topic, the hypothesis carries the keyword
template = "They are talking about "
premise, hypothesis = template + "Fashion industry", template + "Pollution of textile production"

x = tokenizer(premise, hypothesis, return_tensors="pt").to(device)
prediction = m(ort_model(**x).logits).squeeze() > 0.5  # logits shape: (batch_size, 1)
# prediction -> tensor(True): the keyword is relevant to the topic
```

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (an illustrative `TrainingArguments` sketch follows the list):

- learning_rate: 2e-4
- train_batch_size: 64
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-06
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 20
- num_epochs: 2
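The exact training script is not part of this repository. As a rough, non-authoritative reconstruction, the hyperparameters above map onto Hugging Face `TrainingArguments` roughly as follows; `output_dir`, the evaluation/logging cadence (every 50 steps, inferred from the results table below), and the surrounding PEFT/LoRA setup implied by the framework versions are assumptions:

```python
from transformers import TrainingArguments

# Hypothetical reconstruction of the training configuration; not the released script.
training_args = TrainingArguments(
    output_dir="finetuned_bart_inclusion",  # assumed name
    learning_rate=2e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-06,
    lr_scheduler_type="linear",
    warmup_steps=20,
    num_train_epochs=2,
    eval_strategy="steps",  # requires transformers >= 4.41
    eval_steps=50,          # evaluation every 50 steps, inferred from the results table
    logging_steps=50,
)
```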
### Training results

Metrics (F1, precision, recall, accuracy) are calculated on the validation set.

| Step | Training Loss | Validation Loss | F1       | Precision | Recall   | Accuracy |
| ---- | ------------- | --------------- | -------- | --------- | -------- | -------- |
| 50   | 0.441800      | 0.285273        | 0.882523 | 0.892977  | 0.872311 | 0.883930 |
| 100  | 0.328400      | 0.259236        | 0.892380 | 0.903460  | 0.881568 | 0.893727 |
| 150  | 0.287400      | 0.257763        | 0.908518 | 0.877091  | 0.942282 | 0.905157 |
| 200  | 0.286300      | 0.243944        | 0.909042 | 0.898643  | 0.919684 | 0.908015 |
| 250  | 0.275500      | 0.239769        | 0.894515 | 0.925225  | 0.865777 | 0.897945 |
| 300  | 0.271800      | 0.222483        | 0.912901 | 0.912653  | 0.913150 | 0.912913 |
| 350  | 0.245500      | 0.221774        | 0.916039 | 0.905560  | 0.926763 | 0.915090 |
| 400  | 0.250700      | 0.219120        | 0.912324 | 0.924044  | 0.900898 | 0.913458 |
| 450  | 0.249000      | 0.211039        | 0.914482 | 0.922204  | 0.906888 | 0.915227 |
| 500  | 0.241500      | 0.207655        | 0.927005 | 0.910691  | 0.943915 | 0.925704 |
| 550  | 0.249900      | 0.197901        | 0.924230 | 0.925239  | 0.923224 | 0.924343 |
| 600  | 0.236000      | 0.202164        | 0.921937 | 0.929204  | 0.914784 | 0.922574 |
| 650  | 0.196100      | 0.192816        | 0.931687 | 0.918740  | 0.945004 | 0.930739 |
| 700  | 0.214800      | 0.206045        | 0.930494 | 0.894499  | 0.969507 | 0.927609 |
| 750  | 0.223600      | 0.186433        | 0.931180 | 0.928533  | 0.933842 | 0.931011 |
| 800  | 0.223900      | 0.189542        | 0.933564 | 0.911757  | 0.956439 | 0.931964 |
| 850  | 0.197500      | 0.191664        | 0.930473 | 0.928204  | 0.932753 | 0.930331 |
| 900  | 0.194600      | 0.185483        | 0.930797 | 0.922460  | 0.939287 | 0.930195 |
| 950  | 0.190200      | 0.183808        | 0.934791 | 0.916100  | 0.954261 | 0.933460 |
| 1000 | 0.189700      | 0.181666        | 0.934212 | 0.923404  | 0.945276 | 0.933460 |
| 1050 | 0.199300      | 0.177857        | 0.933693 | 0.924473  | 0.943098 | 0.933052 |

Final training output: `TrainOutput(global_step=1072, training_loss=0.2457642003671447, metrics={'train_runtime': 3750.3603, 'train_samples_per_second': 18.289, 'train_steps_per_second': 0.286, 'total_flos': 7425156147297000.0, 'train_loss': 0.2457642003671447, 'epoch': 2.0})`

## Evaluation on test set

| Precision | Recall | F-score |
| --------- | ------ | ------- |
| 0.94      | 0.94   | 0.94    |

### Framework versions

- PEFT 0.10.0
- Transformers 4.41.2
- PyTorch 2.2.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1
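## Batch inference example

The single-pair snippet above extends naturally to scoring several keywords against one topic in a single forward pass. The sketch below assumes `ort_model` and `tokenizer` have already been loaded as shown earlier; the helper name `score_keywords` and the 0.5 decision threshold are illustrative choices, not part of the released code.

```python
from torch import nn

def score_keywords(topic, keywords, threshold=0.5):
    """Classify each keyword as relevant/irrelevant to a single topic.

    Assumes `ort_model` and `tokenizer` from the loading snippet above are in scope.
    """
    template = "They are talking about "
    premises = [template + topic] * len(keywords)      # premise carries the topic
    hypotheses = [template + kw for kw in keywords]    # hypothesis carries the keyword

    batch = tokenizer(premises, hypotheses, padding=True, truncation=True,
                      return_tensors="pt").to(ort_model.device)
    probs = nn.Sigmoid()(ort_model(**batch).logits).squeeze(-1)  # shape: (len(keywords),)
    return {kw: bool(p > threshold) for kw, p in zip(keywords, probs.tolist())}

# Using the examples from the table above:
# score_keywords("Fashion industry", ["Pollution of textile production", "Football athlete"])
# -> {'Pollution of textile production': True, 'Football athlete': False}
```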