
Action_agent

This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the agent_action_class dataset. It achieves the following results on the evaluation set:

  • Loss: 0.9962

  • Accuracy: 0.8243

  • Confusion Matrix (rows: true label, columns: predicted label; classes 0 to 9): [[39, 3, 0, 0, 2, 1, 0, 1, 3, 3], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [1, 0, 38, 2, 1, 4, 0, 5, 0, 0], [4, 1, 0, 39, 0, 3, 0, 0, 0, 8], [1, 1, 2, 1, 50, 0, 0, 0, 0, 1], [0, 0, 7, 1, 1, 44, 1, 0, 0, 2], [3, 0, 0, 1, 1, 0, 55, 0, 2, 1], [0, 0, 3, 1, 0, 0, 0, 52, 0, 0], [2, 9, 0, 0, 0, 0, 9, 1, 39, 0], [0, 0, 0, 2, 0, 1, 0, 1, 0, 56]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7800    0.7500    0.7647        52
               1    0.8028    0.9500    0.8702        60
               2    0.7600    0.7451    0.7525        51
               3    0.8298    0.7091    0.7647        55
               4    0.9091    0.8929    0.9009        56
               5    0.8302    0.7857    0.8073        56
               6    0.8333    0.8730    0.8527        63
               7    0.8667    0.9286    0.8966        56
               8    0.8667    0.6500    0.7429        60
               9    0.7778    0.9333    0.8485        60

        accuracy                        0.8243       569
       macro avg    0.8256    0.8218    0.8201       569
    weighted avg    0.8264    0.8243    0.8216       569
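The per-class figures in the report follow directly from the confusion matrix: precision for class k is the diagonal entry divided by the sum of column k (everything predicted as k), and recall is the diagonal entry divided by the sum of row k (every true example of k). The minimal, standard-library sketch below recomputes them from the final matrix; the helper names `per_class_metrics`, `accuracy` and `top_confusions` are illustrative only and are not part of any library.

```python
# Final evaluation confusion matrix copied from this card.
# Rows are true labels, columns are predicted labels (classes 0 to 9).
CONFUSION = [
    [39, 3, 0, 0, 2, 1, 0, 1, 3, 3],
    [0, 57, 0, 0, 0, 0, 1, 0, 1, 1],
    [1, 0, 38, 2, 1, 4, 0, 5, 0, 0],
    [4, 1, 0, 39, 0, 3, 0, 0, 0, 8],
    [1, 1, 2, 1, 50, 0, 0, 0, 0, 1],
    [0, 0, 7, 1, 1, 44, 1, 0, 0, 2],
    [3, 0, 0, 1, 1, 0, 55, 0, 2, 1],
    [0, 0, 3, 1, 0, 0, 0, 52, 0, 0],
    [2, 9, 0, 0, 0, 0, 9, 1, 39, 0],
    [0, 0, 0, 2, 0, 1, 0, 1, 0, 56],
]


def per_class_metrics(cm):
    """Return one (precision, recall, f1, support) tuple per class."""
    n = len(cm)
    rows = []
    for k in range(n):
        tp = cm[k][k]
        support = sum(cm[k])                          # row sum: true examples of class k
        predicted = sum(cm[i][k] for i in range(n))   # column sum: predictions of class k
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append((precision, recall, f1, support))
    return rows


def accuracy(cm):
    """Fraction of examples on the diagonal."""
    total = sum(sum(row) for row in cm)
    return sum(cm[k][k] for k in range(len(cm))) / total


def top_confusions(cm, n=3):
    """Largest off-diagonal cells as (count, true_label, predicted_label)."""
    size = len(cm)
    cells = [(cm[i][j], i, j) for i in range(size) for j in range(size) if i != j]
    return sorted(cells, reverse=True)[:n]


if __name__ == "__main__":
    for label, (p, r, f, s) in enumerate(per_class_metrics(CONFUSION)):
        print(f"{label:>12}{p:10.4f}{r:10.4f}{f:10.4f}{s:10d}")
    print(f"accuracy {accuracy(CONFUSION):.4f}")
    print(top_confusions(CONFUSION))
```

Run against the matrix above, this should reproduce the report's columns to four decimals (for example, class 0: 39/50 = 0.7800 precision, 39/52 = 0.7500 recall). `top_confusions` makes the weak spots explicit: the largest errors are class 8 predicted as 6 or as 1 (9 each) and class 3 predicted as 9 (8).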

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 1e-05
  • train_batch_size: 32
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 10
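With `lr_scheduler_type: linear`, the learning rate starts at 1e-05 and decays linearly to 0 by the end of training. The sketch below mirrors the formula used by `get_linear_schedule_with_warmup` in Transformers. Two values are assumptions, not stated in this card: zero warmup steps (none is listed), and 1340 total steps, inferred from the training log (step 1300 is logged at epoch 9.7, i.e. roughly 134 optimizer steps per epoch over 10 epochs).

```python
LEARNING_RATE = 1e-5
# Assumption: ~134 optimizer steps per epoch, inferred from the training log
# (step 1300 at epoch 9.7), so 10 epochs give about 1340 steps.
TOTAL_STEPS = 1340
# Assumption: no warmup, since the card does not list any.
WARMUP_STEPS = 0


def linear_lr(step, base_lr=LEARNING_RATE, total=TOTAL_STEPS, warmup=WARMUP_STEPS):
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to 0."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    remaining = max(0, total - step)          # clamp so the rate never goes negative
    return base_lr * remaining / max(1, total - warmup)


if __name__ == "__main__":
    for step in (0, 100, 670, 1300, 1340):
        print(step, f"{linear_lr(step):.3e}")
```

Under these assumptions the rate is about 5e-06 halfway through (step 670) and roughly 3e-07 at the last logged evaluation (step 1300).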

Training results

Training Loss  Epoch  Step  Validation Loss  Accuracy
2.1982         0.75   100   2.1583           0.4851
1.988          1.49   200   1.9350           0.6257
1.7347         2.24   300   1.6937           0.7223
1.5713         2.99   400   1.4857           0.7434
1.3821         3.73   500   1.3477           0.7575
1.3065         4.48   600   1.2437           0.7856
1.2329         5.22   700   1.1645           0.7909
1.1736         5.97   800   1.1159           0.7891
1.1396         6.72   900   1.0749           0.8067
1.0577         7.46   1000  1.0399           0.8155
0.9935         8.21   1100  1.0205           0.8190
1.1058         8.96   1200  1.0022           0.8225
1.0422         9.7    1300  0.9962           0.8243

The model was evaluated every 100 training steps on the same 569 examples. The confusion matrix (rows: true label, columns: predicted label) and classification report for each evaluation follow.

Step 100 (epoch 0.75)

  • Confusion Matrix: [[2, 3, 2, 1, 3, 1, 7, 15, 10, 8], [1, 52, 0, 0, 2, 0, 0, 2, 2, 1], [1, 0, 15, 0, 5, 0, 3, 23, 3, 1], [2, 1, 8, 12, 5, 0, 6, 6, 1, 14], [0, 2, 9, 1, 30, 2, 2, 3, 2, 5], [0, 2, 6, 2, 5, 16, 2, 16, 4, 3], [0, 7, 0, 1, 5, 2, 27, 1, 12, 8], [0, 0, 1, 0, 0, 0, 1, 54, 0, 0], [0, 11, 1, 0, 3, 2, 5, 7, 31, 0], [0, 3, 4, 1, 4, 1, 1, 6, 3, 37]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.3333    0.0385    0.0690        52
               1    0.6420    0.8667    0.7376        60
               2    0.3261    0.2941    0.3093        51
               3    0.6667    0.2182    0.3288        55
               4    0.4839    0.5357    0.5085        56
               5    0.6667    0.2857    0.4000        56
               6    0.5000    0.4286    0.4615        63
               7    0.4060    0.9643    0.5714        56
               8    0.4559    0.5167    0.4844        60
               9    0.4805    0.6167    0.5401        60

        accuracy                        0.4851       569
       macro avg    0.4961    0.4765    0.4411       569
    weighted avg    0.4991    0.4851    0.4484       569

Step 200 (epoch 1.49)

  • Confusion Matrix: [[11, 6, 2, 0, 7, 1, 3, 10, 7, 5], [0, 58, 0, 0, 1, 0, 0, 0, 1, 0], [1, 1, 19, 0, 4, 1, 1, 24, 0, 0], [1, 1, 5, 16, 3, 0, 6, 7, 0, 16], [1, 1, 1, 0, 50, 0, 2, 0, 0, 1], [1, 0, 11, 0, 6, 25, 0, 11, 0, 2], [2, 8, 1, 1, 3, 1, 38, 2, 5, 2], [0, 0, 1, 0, 0, 0, 0, 55, 0, 0], [1, 12, 0, 0, 1, 1, 5, 6, 34, 0], [1, 0, 2, 3, 2, 0, 0, 2, 0, 50]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.5789    0.2115    0.3099        52
               1    0.6667    0.9667    0.7891        60
               2    0.4524    0.3725    0.4086        51
               3    0.8000    0.2909    0.4267        55
               4    0.6494    0.8929    0.7519        56
               5    0.8621    0.4464    0.5882        56
               6    0.6909    0.6032    0.6441        63
               7    0.4701    0.9821    0.6358        56
               8    0.7234    0.5667    0.6355        60
               9    0.6579    0.8333    0.7353        60

        accuracy                        0.6257       569
       macro avg    0.6552    0.6166    0.5925       569
    weighted avg    0.6583    0.6257    0.5997       569

Step 300 (epoch 2.24)

  • Confusion Matrix: [[28, 4, 2, 1, 4, 1, 1, 1, 6, 4], [0, 58, 0, 0, 0, 0, 1, 0, 1, 0], [3, 0, 28, 0, 1, 1, 1, 16, 0, 1], [2, 2, 2, 29, 1, 0, 2, 2, 0, 15], [2, 1, 1, 0, 49, 0, 1, 0, 0, 2], [1, 0, 6, 0, 3, 35, 1, 8, 0, 2], [4, 5, 1, 1, 1, 0, 38, 1, 10, 2], [0, 0, 0, 0, 0, 0, 0, 56, 0, 0], [6, 11, 0, 0, 1, 0, 5, 2, 35, 0], [0, 0, 2, 2, 0, 0, 0, 1, 0, 55]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.6087    0.5385    0.5714        52
               1    0.7160    0.9667    0.8227        60
               2    0.6667    0.5490    0.6022        51
               3    0.8788    0.5273    0.6591        55
               4    0.8167    0.8750    0.8448        56
               5    0.9459    0.6250    0.7527        56
               6    0.7600    0.6032    0.6726        63
               7    0.6437    1.0000    0.7832        56
               8    0.6731    0.5833    0.6250        60
               9    0.6790    0.9167    0.7801        60

        accuracy                        0.7223       569
       macro avg    0.7389    0.7185    0.7114       569
    weighted avg    0.7394    0.7223    0.7136       569

Step 400 (epoch 2.99)

  • Confusion Matrix: [[26, 6, 2, 1, 5, 1, 0, 2, 5, 4], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [2, 0, 29, 1, 2, 2, 2, 13, 0, 0], [3, 1, 4, 32, 1, 1, 0, 1, 0, 12], [1, 1, 1, 0, 49, 0, 1, 0, 0, 3], [1, 0, 6, 0, 4, 41, 0, 2, 0, 2], [3, 5, 1, 0, 1, 0, 42, 0, 8, 3], [0, 0, 0, 1, 0, 0, 0, 55, 0, 0], [4, 11, 0, 0, 0, 0, 8, 2, 35, 0], [0, 0, 2, 0, 0, 0, 0, 1, 0, 57]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.6500    0.5000    0.5652        52
               1    0.7037    0.9500    0.8085        60
               2    0.6444    0.5686    0.6042        51
               3    0.9143    0.5818    0.7111        55
               4    0.7903    0.8750    0.8305        56
               5    0.9111    0.7321    0.8119        56
               6    0.7778    0.6667    0.7179        63
               7    0.7237    0.9821    0.8333        56
               8    0.7143    0.5833    0.6422        60
               9    0.6951    0.9500    0.8028        60

        accuracy                        0.7434       569
       macro avg    0.7525    0.7390    0.7328       569
    weighted avg    0.7532    0.7434    0.7353       569

Step 500 (epoch 3.73)

  • Confusion Matrix: [[30, 4, 0, 3, 4, 1, 0, 2, 4, 4], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [2, 0, 30, 4, 1, 2, 1, 10, 0, 1], [3, 2, 2, 27, 0, 1, 0, 2, 0, 18], [1, 1, 1, 0, 49, 0, 1, 0, 0, 3], [1, 0, 5, 0, 1, 44, 1, 1, 0, 3], [4, 0, 1, 1, 1, 0, 49, 0, 3, 4], [0, 0, 2, 1, 0, 0, 0, 53, 0, 0], [3, 11, 0, 0, 0, 0, 10, 2, 34, 0], [0, 0, 1, 0, 0, 0, 0, 1, 0, 58]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.6818    0.5769    0.6250        52
               1    0.7600    0.9500    0.8444        60
               2    0.7143    0.5882    0.6452        51
               3    0.7500    0.4909    0.5934        55
               4    0.8750    0.8750    0.8750        56
               5    0.9167    0.7857    0.8462        56
               6    0.7778    0.7778    0.7778        63
               7    0.7465    0.9464    0.8346        56
               8    0.8095    0.5667    0.6667        60
               9    0.6304    0.9667    0.7632        60

        accuracy                        0.7575       569
       macro avg    0.7662    0.7524    0.7471       569
    weighted avg    0.7667    0.7575    0.7498       569

Step 600 (epoch 4.48)

  • Confusion Matrix: [[33, 4, 0, 1, 3, 1, 0, 2, 4, 4], [0, 56, 0, 0, 0, 0, 1, 0, 2, 1], [1, 0, 29, 5, 1, 2, 1, 12, 0, 0], [2, 1, 1, 36, 0, 3, 0, 2, 0, 10], [1, 1, 1, 1, 50, 0, 0, 0, 0, 2], [1, 0, 4, 1, 1, 42, 1, 4, 0, 2], [3, 0, 0, 0, 1, 0, 53, 0, 3, 3], [0, 0, 0, 1, 0, 0, 0, 55, 0, 0], [4, 9, 0, 0, 0, 0, 9, 1, 37, 0], [0, 0, 0, 2, 0, 1, 0, 1, 0, 56]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7333    0.6346    0.6804        52
               1    0.7887    0.9333    0.8550        60
               2    0.8286    0.5686    0.6744        51
               3    0.7660    0.6545    0.7059        55
               4    0.8929    0.8929    0.8929        56
               5    0.8571    0.7500    0.8000        56
               6    0.8154    0.8413    0.8281        63
               7    0.7143    0.9821    0.8271        56
               8    0.8043    0.6167    0.6981        60
               9    0.7179    0.9333    0.8116        60

        accuracy                        0.7856       569
       macro avg    0.7919    0.7807    0.7773       569
    weighted avg    0.7918    0.7856    0.7799       569

Step 700 (epoch 5.22)

  • Confusion Matrix: [[34, 4, 0, 1, 3, 1, 0, 1, 4, 4], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [1, 0, 33, 5, 1, 3, 1, 7, 0, 0], [3, 1, 1, 31, 1, 2, 0, 1, 0, 15], [1, 1, 1, 1, 50, 0, 0, 0, 0, 2], [1, 0, 7, 1, 2, 43, 0, 0, 0, 2], [2, 0, 0, 0, 1, 0, 56, 0, 1, 3], [0, 0, 2, 1, 0, 0, 0, 53, 0, 0], [2, 11, 0, 0, 0, 0, 10, 1, 36, 0], [0, 0, 0, 1, 0, 1, 0, 1, 0, 57]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7727    0.6538    0.7083        52
               1    0.7703    0.9500    0.8507        60
               2    0.7500    0.6471    0.6947        51
               3    0.7561    0.5636    0.6458        55
               4    0.8621    0.8929    0.8772        56
               5    0.8600    0.7679    0.8113        56
               6    0.8235    0.8889    0.8550        63
               7    0.8281    0.9464    0.8833        56
               8    0.8571    0.6000    0.7059        60
               9    0.6786    0.9500    0.7917        60

        accuracy                        0.7909       569
       macro avg    0.7959    0.7861    0.7824       569
    weighted avg    0.7963    0.7909    0.7848       569

Step 800 (epoch 5.97)

  • Confusion Matrix: [[35, 4, 0, 0, 2, 1, 1, 1, 4, 4], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [2, 0, 35, 2, 1, 3, 1, 7, 0, 0], [3, 1, 0, 34, 0, 3, 0, 1, 0, 13], [1, 1, 2, 1, 49, 0, 0, 0, 0, 2], [1, 0, 7, 1, 1, 43, 1, 0, 0, 2], [3, 0, 0, 0, 1, 0, 51, 0, 4, 4], [0, 0, 3, 1, 0, 0, 0, 52, 0, 0], [4, 10, 0, 0, 0, 0, 8, 1, 37, 0], [0, 0, 0, 3, 0, 0, 0, 1, 0, 56]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7143    0.6731    0.6931        52
               1    0.7808    0.9500    0.8571        60
               2    0.7447    0.6863    0.7143        51
               3    0.8095    0.6182    0.7010        55
               4    0.9074    0.8750    0.8909        56
               5    0.8600    0.7679    0.8113        56
               6    0.8095    0.8095    0.8095        63
               7    0.8254    0.9286    0.8739        56
               8    0.8043    0.6167    0.6981        60
               9    0.6829    0.9333    0.7887        60

        accuracy                        0.7891       569
       macro avg    0.7939    0.7858    0.7838       569
    weighted avg    0.7942    0.7891    0.7855       569

Step 900 (epoch 6.72)

  • Confusion Matrix: [[39, 3, 0, 0, 1, 1, 0, 2, 3, 3], [1, 56, 0, 0, 0, 0, 1, 0, 1, 1], [2, 0, 38, 1, 1, 3, 0, 6, 0, 0], [3, 1, 1, 33, 0, 3, 0, 1, 0, 13], [1, 1, 2, 1, 50, 0, 0, 0, 0, 1], [0, 0, 7, 1, 1, 44, 1, 0, 0, 2], [3, 0, 0, 0, 1, 0, 53, 0, 2, 4], [0, 0, 3, 1, 0, 0, 0, 52, 0, 0], [5, 9, 0, 0, 0, 0, 8, 1, 37, 0], [0, 0, 0, 1, 0, 1, 0, 1, 0, 57]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7222    0.7500    0.7358        52
               1    0.8000    0.9333    0.8615        60
               2    0.7451    0.7451    0.7451        51
               3    0.8684    0.6000    0.7097        55
               4    0.9259    0.8929    0.9091        56
               5    0.8462    0.7857    0.8148        56
               6    0.8413    0.8413    0.8413        63
               7    0.8254    0.9286    0.8739        56
               8    0.8605    0.6167    0.7184        60
               9    0.7037    0.9500    0.8085        60

        accuracy                        0.8067       569
       macro avg    0.8139    0.8044    0.8018       569
    weighted avg    0.8148    0.8067    0.8033       569

Step 1000 (epoch 7.46)

  • Confusion Matrix: [[37, 3, 0, 0, 1, 1, 1, 2, 4, 3], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [1, 0, 38, 4, 1, 4, 0, 3, 0, 0], [3, 1, 0, 40, 0, 3, 0, 1, 0, 7], [1, 1, 2, 1, 50, 0, 0, 0, 0, 1], [0, 0, 6, 1, 1, 45, 1, 0, 0, 2], [3, 0, 0, 2, 1, 0, 53, 0, 2, 2], [0, 0, 3, 1, 0, 0, 0, 52, 0, 0], [3, 9, 0, 0, 0, 0, 9, 1, 38, 0], [0, 0, 0, 4, 0, 1, 0, 1, 0, 54]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7708    0.7115    0.7400        52
               1    0.8028    0.9500    0.8702        60
               2    0.7755    0.7451    0.7600        51
               3    0.7547    0.7273    0.7407        55
               4    0.9259    0.8929    0.9091        56
               5    0.8333    0.8036    0.8182        56
               6    0.8154    0.8413    0.8281        63
               7    0.8667    0.9286    0.8966        56
               8    0.8444    0.6333    0.7238        60
               9    0.7714    0.9000    0.8308        60

        accuracy                        0.8155       569
       macro avg    0.8161    0.8134    0.8117       569
    weighted avg    0.8167    0.8155    0.8130       569

Step 1100 (epoch 8.21)

  • Confusion Matrix: [[38, 4, 0, 0, 1, 1, 0, 2, 3, 3], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [1, 0, 38, 2, 1, 3, 0, 6, 0, 0], [3, 1, 0, 38, 0, 3, 0, 1, 0, 9], [1, 1, 2, 1, 50, 0, 0, 0, 0, 1], [0, 0, 7, 1, 2, 44, 0, 0, 0, 2], [3, 0, 0, 2, 1, 0, 54, 0, 2, 1], [0, 0, 2, 1, 0, 0, 0, 53, 0, 0], [2, 10, 0, 0, 0, 0, 9, 1, 38, 0], [0, 0, 0, 2, 0, 1, 0, 1, 0, 56]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7917    0.7308    0.7600        52
               1    0.7808    0.9500    0.8571        60
               2    0.7755    0.7451    0.7600        51
               3    0.8085    0.6909    0.7451        55
               4    0.9091    0.8929    0.9009        56
               5    0.8462    0.7857    0.8148        56
               6    0.8438    0.8571    0.8504        63
               7    0.8281    0.9464    0.8833        56
               8    0.8636    0.6333    0.7308        60
               9    0.7671    0.9333    0.8421        60

        accuracy                        0.8190       569
       macro avg    0.8214    0.8166    0.8145       569
    weighted avg    0.8220    0.8190    0.8158       569

Step 1200 (epoch 8.96)

  • Confusion Matrix: [[38, 3, 0, 0, 2, 1, 1, 1, 3, 3], [0, 57, 0, 0, 0, 0, 1, 0, 1, 1], [1, 0, 37, 2, 1, 5, 0, 5, 0, 0], [4, 1, 0, 39, 0, 3, 0, 0, 0, 8], [1, 1, 2, 1, 50, 0, 0, 0, 0, 1], [0, 0, 6, 1, 1, 45, 1, 0, 0, 2], [3, 0, 0, 1, 1, 0, 55, 0, 2, 1], [0, 0, 3, 1, 0, 0, 0, 52, 0, 0], [3, 9, 0, 0, 0, 0, 9, 0, 39, 0], [0, 0, 0, 2, 0, 1, 0, 1, 0, 56]]

  • Classification Report:

                 precision    recall  f1-score   support

               0    0.7600    0.7308    0.7451        52
               1    0.8028    0.9500    0.8702        60
               2    0.7708    0.7255    0.7475        51
               3    0.8298    0.7091    0.7647        55
               4    0.9091    0.8929    0.9009        56
               5    0.8182    0.8036    0.8108        56
               6    0.8209    0.8730    0.8462        63
               7    0.8814    0.9286    0.9043        56
               8    0.8667    0.6500    0.7429        60
               9    0.7778    0.9333    0.8485        60

        accuracy                        0.8225       569
       macro avg    0.8237    0.8197    0.8181       569
    weighted avg    0.8244    0.8225    0.8197       569

Step 1300 (epoch 9.7)

The confusion matrix and classification report at this step are identical to the final evaluation results at the top of this card (accuracy 0.8243, macro avg F1 0.8201, weighted avg F1 0.8216).
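The summary table makes it easy to check which checkpoint a "keep the best model" rule would select, similar in spirit to the Trainer's `load_best_model_at_end` / `metric_for_best_model` options. The card does not say whether those options were used, so the sketch below is only an illustration; `EVAL_LOG`, `best_step` and `regressions` are hypothetical names, and the rows are copied from the table above.

```python
# (step, validation_loss, accuracy) rows copied from the training results table.
EVAL_LOG = [
    (100, 2.1583, 0.4851),
    (200, 1.9350, 0.6257),
    (300, 1.6937, 0.7223),
    (400, 1.4857, 0.7434),
    (500, 1.3477, 0.7575),
    (600, 1.2437, 0.7856),
    (700, 1.1645, 0.7909),
    (800, 1.1159, 0.7891),
    (900, 1.0749, 0.8067),
    (1000, 1.0399, 0.8155),
    (1100, 1.0205, 0.8190),
    (1200, 1.0022, 0.8225),
    (1300, 0.9962, 0.8243),
]


def best_step(log, metric="loss"):
    """Step with the lowest validation loss, or the highest accuracy."""
    if metric == "loss":
        return min(log, key=lambda row: row[1])[0]
    if metric == "accuracy":
        return max(log, key=lambda row: row[2])[0]
    raise ValueError(f"unknown metric: {metric!r}")


def regressions(log):
    """Steps where accuracy dropped compared with the previous evaluation."""
    return [cur[0] for prev, cur in zip(log, log[1:]) if cur[2] < prev[2]]


if __name__ == "__main__":
    print("best by validation loss:", best_step(EVAL_LOG))
    print("best by accuracy:", best_step(EVAL_LOG, "accuracy"))
    print("accuracy dipped at steps:", regressions(EVAL_LOG))
```

Here validation loss falls at every evaluation and accuracy dips only once (step 800), so both criteria pick the final checkpoint at step 1300, which matches the evaluation results reported at the top of the card.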

Framework versions

  • Transformers 4.39.3
  • Pytorch 2.1.2
  • Datasets 2.18.0
  • Tokenizers 0.15.2

Model weights

  • Format: Safetensors
  • Model size: 85.8M params
  • Tensor type: F32