
Action_agent

This model is a fine-tuned version of google/vit-base-patch16-224-in21k on the agent_action_class dataset. It achieves the following results on the evaluation set:

  • Loss: 0.9962

  • Accuracy: 0.8243

  • Confusion Matrix (rows: true label, columns: predicted label):

      [[39,  3,  0,  0,  2,  1,  0,  1,  3,  3],
       [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
       [ 1,  0, 38,  2,  1,  4,  0,  5,  0,  0],
       [ 4,  1,  0, 39,  0,  3,  0,  0,  0,  8],
       [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
       [ 0,  0,  7,  1,  1, 44,  1,  0,  0,  2],
       [ 3,  0,  0,  1,  1,  0, 55,  0,  2,  1],
       [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
       [ 2,  9,  0,  0,  0,  0,  9,  1, 39,  0],
       [ 0,  0,  0,  2,  0,  1,  0,  1,  0, 56]]

  • Classification Report:

            precision    recall  f1-score   support

         0     0.7800    0.7500    0.7647        52
         1     0.8028    0.9500    0.8702        60
         2     0.7600    0.7451    0.7525        51
         3     0.8298    0.7091    0.7647        55
         4     0.9091    0.8929    0.9009        56
         5     0.8302    0.7857    0.8073        56
         6     0.8333    0.8730    0.8527        63
         7     0.8667    0.9286    0.8966        56
         8     0.8667    0.6500    0.7429        60
         9     0.7778    0.9333    0.8485        60
    

    accuracy                       0.8243       569
   macro avg   0.8256    0.8218    0.8201       569
weighted avg   0.8264    0.8243    0.8216       569
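Every figure in the classification report can be recomputed from the confusion matrix: row sums give the support, column sums give how often each class was predicted, and precision, recall and F1 follow from the diagonal. The pure-Python sketch below reproduces the final-checkpoint report; the helper name `per_class_metrics` is illustrative, not part of any library. It also lists the largest off-diagonal cells, which show that most of class 8's missed examples were predicted as class 1 or class 6.

```python
# Final-checkpoint confusion matrix from this card.
# Rows are true labels 0-9; columns are predicted labels 0-9.
CONFUSION = [
    [39, 3, 0, 0, 2, 1, 0, 1, 3, 3],
    [0, 57, 0, 0, 0, 0, 1, 0, 1, 1],
    [1, 0, 38, 2, 1, 4, 0, 5, 0, 0],
    [4, 1, 0, 39, 0, 3, 0, 0, 0, 8],
    [1, 1, 2, 1, 50, 0, 0, 0, 0, 1],
    [0, 0, 7, 1, 1, 44, 1, 0, 0, 2],
    [3, 0, 0, 1, 1, 0, 55, 0, 2, 1],
    [0, 0, 3, 1, 0, 0, 0, 52, 0, 0],
    [2, 9, 0, 0, 0, 0, 9, 1, 39, 0],
    [0, 0, 0, 2, 0, 1, 0, 1, 0, 56],
]


def per_class_metrics(cm):
    """Return precision, recall, F1 and support lists (one entry per class)."""
    n = len(cm)
    support = [sum(row) for row in cm]                                # true examples per class
    predicted = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predictions per class
    precision, recall, f1 = [], [], []
    for k in range(n):
        tp = cm[k][k]
        p = tp / predicted[k] if predicted[k] else 0.0
        r = tp / support[k] if support[k] else 0.0
        precision.append(p)
        recall.append(r)
        f1.append(2 * p * r / (p + r) if p + r else 0.0)
    return precision, recall, f1, support


precision, recall, f1, support = per_class_metrics(CONFUSION)
total = sum(support)
accuracy = sum(CONFUSION[k][k] for k in range(len(CONFUSION))) / total
macro = [sum(m) / len(m) for m in (precision, recall, f1)]
weighted = [sum(v * s for v, s in zip(m, support)) / total for m in (precision, recall, f1)]

for k, (p, r, f, s) in enumerate(zip(precision, recall, f1, support)):
    print(f"{k:>12} {p:9.4f} {r:9.4f} {f:9.4f} {s:9d}")
print(f"{'accuracy':>12} {accuracy:29.4f} {total:9d}")
print(f"{'macro avg':>12}", " ".join(f"{v:9.4f}" for v in macro), f"{total:9d}")
print(f"{'weighted avg':>12}", " ".join(f"{v:9.4f}" for v in weighted), f"{total:9d}")

# Largest off-diagonal cells as (count, true label, predicted label).
n_classes = len(CONFUSION)
mistakes = sorted(
    ((CONFUSION[i][j], i, j) for i in range(n_classes) for j in range(n_classes) if i != j),
    reverse=True,
)
print(mistakes[:4])  # [(9, 8, 6), (9, 8, 1), (8, 3, 9), (7, 5, 2)]
```

Macro averages weight every class equally, while weighted averages use the support column, which is why the weighted recall equals the overall accuracy.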

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 1e-05
  • train_batch_size: 32
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 10
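With `lr_scheduler_type: linear`, the learning rate decays linearly from 1e-05 towards zero over the run. The sketch below follows the same formula as Transformers' `get_linear_schedule_with_warmup`, under two assumptions that this card does not state: zero warmup steps (none is listed), and 134 optimizer steps per epoch (1,340 in total), a figure inferred from the logged step/epoch pairs.

```python
# Linear learning-rate schedule matching lr_scheduler_type "linear".
LEARNING_RATE = 1e-5
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 134  # assumption: inferred from the logged step/epoch pairs
WARMUP_STEPS = 0       # assumption: no warmup is listed among the hyperparameters
TOTAL_STEPS = STEPS_PER_EPOCH * NUM_EPOCHS


def linear_lr(step, base_lr=LEARNING_RATE, total=TOTAL_STEPS, warmup=WARMUP_STEPS):
    """Learning rate after `step` optimizer updates: linear warmup, then linear decay to 0."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))


for step in (0, 100, 670, 1300, TOTAL_STEPS):
    print(f"step {step:>4}: lr = {linear_lr(step):.3e}")
```

Under these assumptions the rate is halved at step 670 and is already down to roughly 3e-07 at the last logged checkpoint (step 1300).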

Training results

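Evaluation metrics were logged every 100 steps. Each checkpoint entry gives the epoch, training loss, validation loss and accuracy, followed by the confusion matrix (rows: true label, columns: predicted label) and the classification report.

Step 100 (epoch 0.75): training loss 2.1982, validation loss 2.1583, accuracy 0.4851

Confusion matrix:
    [[ 2,  3,  2,  1,  3,  1,  7, 15, 10,  8],
     [ 1, 52,  0,  0,  2,  0,  0,  2,  2,  1],
     [ 1,  0, 15,  0,  5,  0,  3, 23,  3,  1],
     [ 2,  1,  8, 12,  5,  0,  6,  6,  1, 14],
     [ 0,  2,  9,  1, 30,  2,  2,  3,  2,  5],
     [ 0,  2,  6,  2,  5, 16,  2, 16,  4,  3],
     [ 0,  7,  0,  1,  5,  2, 27,  1, 12,  8],
     [ 0,  0,  1,  0,  0,  0,  1, 54,  0,  0],
     [ 0, 11,  1,  0,  3,  2,  5,  7, 31,  0],
     [ 0,  3,  4,  1,  4,  1,  1,  6,  3, 37]]

Classification report:
          precision    recall  f1-score   support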
       0     0.3333    0.0385    0.0690        52
       1     0.6420    0.8667    0.7376        60
       2     0.3261    0.2941    0.3093        51
       3     0.6667    0.2182    0.3288        55
       4     0.4839    0.5357    0.5085        56
       5     0.6667    0.2857    0.4000        56
       6     0.5000    0.4286    0.4615        63
       7     0.4060    0.9643    0.5714        56
       8     0.4559    0.5167    0.4844        60
       9     0.4805    0.6167    0.5401        60

accuracy                         0.4851       569

   macro avg 0.4961    0.4765    0.4411       569
weighted avg 0.4991    0.4851    0.4484       569

Step 200 (epoch 1.49): training loss 1.988, validation loss 1.9350, accuracy 0.6257

Confusion matrix:
    [[11,  6,  2,  0,  7,  1,  3, 10,  7,  5],
     [ 0, 58,  0,  0,  1,  0,  0,  0,  1,  0],
     [ 1,  1, 19,  0,  4,  1,  1, 24,  0,  0],
     [ 1,  1,  5, 16,  3,  0,  6,  7,  0, 16],
     [ 1,  1,  1,  0, 50,  0,  2,  0,  0,  1],
     [ 1,  0, 11,  0,  6, 25,  0, 11,  0,  2],
     [ 2,  8,  1,  1,  3,  1, 38,  2,  5,  2],
     [ 0,  0,  1,  0,  0,  0,  0, 55,  0,  0],
     [ 1, 12,  0,  0,  1,  1,  5,  6, 34,  0],
     [ 1,  0,  2,  3,  2,  0,  0,  2,  0, 50]]

Classification report:
          precision    recall  f1-score   support
       0     0.5789    0.2115    0.3099        52
       1     0.6667    0.9667    0.7891        60
       2     0.4524    0.3725    0.4086        51
       3     0.8000    0.2909    0.4267        55
       4     0.6494    0.8929    0.7519        56
       5     0.8621    0.4464    0.5882        56
       6     0.6909    0.6032    0.6441        63
       7     0.4701    0.9821    0.6358        56
       8     0.7234    0.5667    0.6355        60
       9     0.6579    0.8333    0.7353        60

accuracy                         0.6257       569

   macro avg 0.6552    0.6166    0.5925       569
weighted avg 0.6583    0.6257    0.5997       569

Step 300 (epoch 2.24): training loss 1.7347, validation loss 1.6937, accuracy 0.7223

Confusion matrix:
    [[28,  4,  2,  1,  4,  1,  1,  1,  6,  4],
     [ 0, 58,  0,  0,  0,  0,  1,  0,  1,  0],
     [ 3,  0, 28,  0,  1,  1,  1, 16,  0,  1],
     [ 2,  2,  2, 29,  1,  0,  2,  2,  0, 15],
     [ 2,  1,  1,  0, 49,  0,  1,  0,  0,  2],
     [ 1,  0,  6,  0,  3, 35,  1,  8,  0,  2],
     [ 4,  5,  1,  1,  1,  0, 38,  1, 10,  2],
     [ 0,  0,  0,  0,  0,  0,  0, 56,  0,  0],
     [ 6, 11,  0,  0,  1,  0,  5,  2, 35,  0],
     [ 0,  0,  2,  2,  0,  0,  0,  1,  0, 55]]

Classification report:
          precision    recall  f1-score   support
       0     0.6087    0.5385    0.5714        52
       1     0.7160    0.9667    0.8227        60
       2     0.6667    0.5490    0.6022        51
       3     0.8788    0.5273    0.6591        55
       4     0.8167    0.8750    0.8448        56
       5     0.9459    0.6250    0.7527        56
       6     0.7600    0.6032    0.6726        63
       7     0.6437    1.0000    0.7832        56
       8     0.6731    0.5833    0.6250        60
       9     0.6790    0.9167    0.7801        60

accuracy                         0.7223       569

   macro avg 0.7389    0.7185    0.7114       569
weighted avg 0.7394    0.7223    0.7136       569

Step 400 (epoch 2.99): training loss 1.5713, validation loss 1.4857, accuracy 0.7434

Confusion matrix:
    [[26,  6,  2,  1,  5,  1,  0,  2,  5,  4],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 2,  0, 29,  1,  2,  2,  2, 13,  0,  0],
     [ 3,  1,  4, 32,  1,  1,  0,  1,  0, 12],
     [ 1,  1,  1,  0, 49,  0,  1,  0,  0,  3],
     [ 1,  0,  6,  0,  4, 41,  0,  2,  0,  2],
     [ 3,  5,  1,  0,  1,  0, 42,  0,  8,  3],
     [ 0,  0,  0,  1,  0,  0,  0, 55,  0,  0],
     [ 4, 11,  0,  0,  0,  0,  8,  2, 35,  0],
     [ 0,  0,  2,  0,  0,  0,  0,  1,  0, 57]]

Classification report:
          precision    recall  f1-score   support
       0     0.6500    0.5000    0.5652        52
       1     0.7037    0.9500    0.8085        60
       2     0.6444    0.5686    0.6042        51
       3     0.9143    0.5818    0.7111        55
       4     0.7903    0.8750    0.8305        56
       5     0.9111    0.7321    0.8119        56
       6     0.7778    0.6667    0.7179        63
       7     0.7237    0.9821    0.8333        56
       8     0.7143    0.5833    0.6422        60
       9     0.6951    0.9500    0.8028        60

accuracy                         0.7434       569

   macro avg 0.7525    0.7390    0.7328       569
weighted avg 0.7532    0.7434    0.7353       569

Step 500 (epoch 3.73): training loss 1.3821, validation loss 1.3477, accuracy 0.7575

Confusion matrix:
    [[30,  4,  0,  3,  4,  1,  0,  2,  4,  4],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 2,  0, 30,  4,  1,  2,  1, 10,  0,  1],
     [ 3,  2,  2, 27,  0,  1,  0,  2,  0, 18],
     [ 1,  1,  1,  0, 49,  0,  1,  0,  0,  3],
     [ 1,  0,  5,  0,  1, 44,  1,  1,  0,  3],
     [ 4,  0,  1,  1,  1,  0, 49,  0,  3,  4],
     [ 0,  0,  2,  1,  0,  0,  0, 53,  0,  0],
     [ 3, 11,  0,  0,  0,  0, 10,  2, 34,  0],
     [ 0,  0,  1,  0,  0,  0,  0,  1,  0, 58]]

Classification report:
          precision    recall  f1-score   support
       0     0.6818    0.5769    0.6250        52
       1     0.7600    0.9500    0.8444        60
       2     0.7143    0.5882    0.6452        51
       3     0.7500    0.4909    0.5934        55
       4     0.8750    0.8750    0.8750        56
       5     0.9167    0.7857    0.8462        56
       6     0.7778    0.7778    0.7778        63
       7     0.7465    0.9464    0.8346        56
       8     0.8095    0.5667    0.6667        60
       9     0.6304    0.9667    0.7632        60

accuracy                         0.7575       569

   macro avg 0.7662    0.7524    0.7471       569
weighted avg 0.7667    0.7575    0.7498       569

Step 600 (epoch 4.48): training loss 1.3065, validation loss 1.2437, accuracy 0.7856

Confusion matrix:
    [[33,  4,  0,  1,  3,  1,  0,  2,  4,  4],
     [ 0, 56,  0,  0,  0,  0,  1,  0,  2,  1],
     [ 1,  0, 29,  5,  1,  2,  1, 12,  0,  0],
     [ 2,  1,  1, 36,  0,  3,  0,  2,  0, 10],
     [ 1,  1,  1,  1, 50,  0,  0,  0,  0,  2],
     [ 1,  0,  4,  1,  1, 42,  1,  4,  0,  2],
     [ 3,  0,  0,  0,  1,  0, 53,  0,  3,  3],
     [ 0,  0,  0,  1,  0,  0,  0, 55,  0,  0],
     [ 4,  9,  0,  0,  0,  0,  9,  1, 37,  0],
     [ 0,  0,  0,  2,  0,  1,  0,  1,  0, 56]]

Classification report:
          precision    recall  f1-score   support
       0     0.7333    0.6346    0.6804        52
       1     0.7887    0.9333    0.8550        60
       2     0.8286    0.5686    0.6744        51
       3     0.7660    0.6545    0.7059        55
       4     0.8929    0.8929    0.8929        56
       5     0.8571    0.7500    0.8000        56
       6     0.8154    0.8413    0.8281        63
       7     0.7143    0.9821    0.8271        56
       8     0.8043    0.6167    0.6981        60
       9     0.7179    0.9333    0.8116        60

accuracy                         0.7856       569

   macro avg 0.7919    0.7807    0.7773       569
weighted avg 0.7918    0.7856    0.7799       569

Step 700 (epoch 5.22): training loss 1.2329, validation loss 1.1645, accuracy 0.7909

Confusion matrix:
    [[34,  4,  0,  1,  3,  1,  0,  1,  4,  4],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 1,  0, 33,  5,  1,  3,  1,  7,  0,  0],
     [ 3,  1,  1, 31,  1,  2,  0,  1,  0, 15],
     [ 1,  1,  1,  1, 50,  0,  0,  0,  0,  2],
     [ 1,  0,  7,  1,  2, 43,  0,  0,  0,  2],
     [ 2,  0,  0,  0,  1,  0, 56,  0,  1,  3],
     [ 0,  0,  2,  1,  0,  0,  0, 53,  0,  0],
     [ 2, 11,  0,  0,  0,  0, 10,  1, 36,  0],
     [ 0,  0,  0,  1,  0,  1,  0,  1,  0, 57]]

Classification report:
          precision    recall  f1-score   support
       0     0.7727    0.6538    0.7083        52
       1     0.7703    0.9500    0.8507        60
       2     0.7500    0.6471    0.6947        51
       3     0.7561    0.5636    0.6458        55
       4     0.8621    0.8929    0.8772        56
       5     0.8600    0.7679    0.8113        56
       6     0.8235    0.8889    0.8550        63
       7     0.8281    0.9464    0.8833        56
       8     0.8571    0.6000    0.7059        60
       9     0.6786    0.9500    0.7917        60

accuracy                         0.7909       569

   macro avg 0.7959    0.7861    0.7824       569
weighted avg 0.7963    0.7909    0.7848       569

Step 800 (epoch 5.97): training loss 1.1736, validation loss 1.1159, accuracy 0.7891

Confusion matrix:
    [[35,  4,  0,  0,  2,  1,  1,  1,  4,  4],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 2,  0, 35,  2,  1,  3,  1,  7,  0,  0],
     [ 3,  1,  0, 34,  0,  3,  0,  1,  0, 13],
     [ 1,  1,  2,  1, 49,  0,  0,  0,  0,  2],
     [ 1,  0,  7,  1,  1, 43,  1,  0,  0,  2],
     [ 3,  0,  0,  0,  1,  0, 51,  0,  4,  4],
     [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
     [ 4, 10,  0,  0,  0,  0,  8,  1, 37,  0],
     [ 0,  0,  0,  3,  0,  0,  0,  1,  0, 56]]

Classification report:
          precision    recall  f1-score   support
       0     0.7143    0.6731    0.6931        52
       1     0.7808    0.9500    0.8571        60
       2     0.7447    0.6863    0.7143        51
       3     0.8095    0.6182    0.7010        55
       4     0.9074    0.8750    0.8909        56
       5     0.8600    0.7679    0.8113        56
       6     0.8095    0.8095    0.8095        63
       7     0.8254    0.9286    0.8739        56
       8     0.8043    0.6167    0.6981        60
       9     0.6829    0.9333    0.7887        60

accuracy                         0.7891       569

   macro avg 0.7939    0.7858    0.7838       569
weighted avg 0.7942    0.7891    0.7855       569

Step 900 (epoch 6.72): training loss 1.1396, validation loss 1.0749, accuracy 0.8067

Confusion matrix:
    [[39,  3,  0,  0,  1,  1,  0,  2,  3,  3],
     [ 1, 56,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 2,  0, 38,  1,  1,  3,  0,  6,  0,  0],
     [ 3,  1,  1, 33,  0,  3,  0,  1,  0, 13],
     [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
     [ 0,  0,  7,  1,  1, 44,  1,  0,  0,  2],
     [ 3,  0,  0,  0,  1,  0, 53,  0,  2,  4],
     [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
     [ 5,  9,  0,  0,  0,  0,  8,  1, 37,  0],
     [ 0,  0,  0,  1,  0,  1,  0,  1,  0, 57]]

Classification report:
          precision    recall  f1-score   support
       0     0.7222    0.7500    0.7358        52
       1     0.8000    0.9333    0.8615        60
       2     0.7451    0.7451    0.7451        51
       3     0.8684    0.6000    0.7097        55
       4     0.9259    0.8929    0.9091        56
       5     0.8462    0.7857    0.8148        56
       6     0.8413    0.8413    0.8413        63
       7     0.8254    0.9286    0.8739        56
       8     0.8605    0.6167    0.7184        60
       9     0.7037    0.9500    0.8085        60

accuracy                         0.8067       569

   macro avg 0.8139    0.8044    0.8018       569
weighted avg 0.8148    0.8067    0.8033       569

Step 1000 (epoch 7.46): training loss 1.0577, validation loss 1.0399, accuracy 0.8155

Confusion matrix:
    [[37,  3,  0,  0,  1,  1,  1,  2,  4,  3],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 1,  0, 38,  4,  1,  4,  0,  3,  0,  0],
     [ 3,  1,  0, 40,  0,  3,  0,  1,  0,  7],
     [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
     [ 0,  0,  6,  1,  1, 45,  1,  0,  0,  2],
     [ 3,  0,  0,  2,  1,  0, 53,  0,  2,  2],
     [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
     [ 3,  9,  0,  0,  0,  0,  9,  1, 38,  0],
     [ 0,  0,  0,  4,  0,  1,  0,  1,  0, 54]]

Classification report:
          precision    recall  f1-score   support
       0     0.7708    0.7115    0.7400        52
       1     0.8028    0.9500    0.8702        60
       2     0.7755    0.7451    0.7600        51
       3     0.7547    0.7273    0.7407        55
       4     0.9259    0.8929    0.9091        56
       5     0.8333    0.8036    0.8182        56
       6     0.8154    0.8413    0.8281        63
       7     0.8667    0.9286    0.8966        56
       8     0.8444    0.6333    0.7238        60
       9     0.7714    0.9000    0.8308        60

accuracy                         0.8155       569

   macro avg 0.8161    0.8134    0.8117       569
weighted avg 0.8167    0.8155    0.8130       569

Step 1100 (epoch 8.21): training loss 0.9935, validation loss 1.0205, accuracy 0.8190

Confusion matrix:
    [[38,  4,  0,  0,  1,  1,  0,  2,  3,  3],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 1,  0, 38,  2,  1,  3,  0,  6,  0,  0],
     [ 3,  1,  0, 38,  0,  3,  0,  1,  0,  9],
     [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
     [ 0,  0,  7,  1,  2, 44,  0,  0,  0,  2],
     [ 3,  0,  0,  2,  1,  0, 54,  0,  2,  1],
     [ 0,  0,  2,  1,  0,  0,  0, 53,  0,  0],
     [ 2, 10,  0,  0,  0,  0,  9,  1, 38,  0],
     [ 0,  0,  0,  2,  0,  1,  0,  1,  0, 56]]

Classification report:
          precision    recall  f1-score   support
       0     0.7917    0.7308    0.7600        52
       1     0.7808    0.9500    0.8571        60
       2     0.7755    0.7451    0.7600        51
       3     0.8085    0.6909    0.7451        55
       4     0.9091    0.8929    0.9009        56
       5     0.8462    0.7857    0.8148        56
       6     0.8438    0.8571    0.8504        63
       7     0.8281    0.9464    0.8833        56
       8     0.8636    0.6333    0.7308        60
       9     0.7671    0.9333    0.8421        60

accuracy                         0.8190       569

   macro avg 0.8214    0.8166    0.8145       569
weighted avg 0.8220    0.8190    0.8158       569

Step 1200 (epoch 8.96): training loss 1.1058, validation loss 1.0022, accuracy 0.8225

Confusion matrix:
    [[38,  3,  0,  0,  2,  1,  1,  1,  3,  3],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 1,  0, 37,  2,  1,  5,  0,  5,  0,  0],
     [ 4,  1,  0, 39,  0,  3,  0,  0,  0,  8],
     [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
     [ 0,  0,  6,  1,  1, 45,  1,  0,  0,  2],
     [ 3,  0,  0,  1,  1,  0, 55,  0,  2,  1],
     [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
     [ 3,  9,  0,  0,  0,  0,  9,  0, 39,  0],
     [ 0,  0,  0,  2,  0,  1,  0,  1,  0, 56]]

Classification report:
          precision    recall  f1-score   support
       0     0.7600    0.7308    0.7451        52
       1     0.8028    0.9500    0.8702        60
       2     0.7708    0.7255    0.7475        51
       3     0.8298    0.7091    0.7647        55
       4     0.9091    0.8929    0.9009        56
       5     0.8182    0.8036    0.8108        56
       6     0.8209    0.8730    0.8462        63
       7     0.8814    0.9286    0.9043        56
       8     0.8667    0.6500    0.7429        60
       9     0.7778    0.9333    0.8485        60

accuracy                         0.8225       569

   macro avg 0.8237    0.8197    0.8181       569
weighted avg 0.8244    0.8225    0.8197       569

Step 1300 (epoch 9.7): training loss 1.0422, validation loss 0.9962, accuracy 0.8243

Confusion matrix:
    [[39,  3,  0,  0,  2,  1,  0,  1,  3,  3],
     [ 0, 57,  0,  0,  0,  0,  1,  0,  1,  1],
     [ 1,  0, 38,  2,  1,  4,  0,  5,  0,  0],
     [ 4,  1,  0, 39,  0,  3,  0,  0,  0,  8],
     [ 1,  1,  2,  1, 50,  0,  0,  0,  0,  1],
     [ 0,  0,  7,  1,  1, 44,  1,  0,  0,  2],
     [ 3,  0,  0,  1,  1,  0, 55,  0,  2,  1],
     [ 0,  0,  3,  1,  0,  0,  0, 52,  0,  0],
     [ 2,  9,  0,  0,  0,  0,  9,  1, 39,  0],
     [ 0,  0,  0,  2,  0,  1,  0,  1,  0, 56]]

Classification report:
          precision    recall  f1-score   support
       0     0.7800    0.7500    0.7647        52
       1     0.8028    0.9500    0.8702        60
       2     0.7600    0.7451    0.7525        51
       3     0.8298    0.7091    0.7647        55
       4     0.9091    0.8929    0.9009        56
       5     0.8302    0.7857    0.8073        56
       6     0.8333    0.8730    0.8527        63
       7     0.8667    0.9286    0.8966        56
       8     0.8667    0.6500    0.7429        60
       9     0.7778    0.9333    0.8485        60

accuracy                         0.8243       569

   macro avg 0.8256    0.8218    0.8201       569
weighted avg 0.8264    0.8243    0.8216       569
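The card does not state the size of the training split, but the logged step/epoch pairs constrain it. Assuming the epoch column is simply step divided by steps-per-epoch, rounded to two decimals (single device, no gradient accumulation), only one steps-per-epoch value fits all thirteen checkpoints. The helper name `consistent_steps_per_epoch` is hypothetical.

```python
# (step, epoch) pairs copied from the training-results log above.
LOGGED = [
    (100, 0.75), (200, 1.49), (300, 2.24), (400, 2.99), (500, 3.73),
    (600, 4.48), (700, 5.22), (800, 5.97), (900, 6.72), (1000, 7.46),
    (1100, 8.21), (1200, 8.96), (1300, 9.7),
]
BATCH_SIZE = 32  # train_batch_size from the hyperparameter list


def consistent_steps_per_epoch(pairs, candidates=range(100, 200)):
    """Steps-per-epoch values for which step / n, rounded to 2 decimals, matches every logged epoch."""
    return [n for n in candidates if all(round(step / n, 2) == epoch for step, epoch in pairs)]


matches = consistent_steps_per_epoch(LOGGED)
print(matches)  # [134]

# Assumes one device, no gradient accumulation, and a final partial batch that is kept:
# n - 1 full batches are too few, n batches are enough.
n = matches[0]
min_train, max_train = (n - 1) * BATCH_SIZE + 1, n * BATCH_SIZE
print(min_train, max_train)  # 4257 4288
```

Under these assumptions the training split held between 4,257 and 4,288 images, and the ten epochs ran for about 1,340 steps, so the last logged evaluation (step 1300) came shortly before the end of training.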

Framework versions

  • Transformers 4.39.3
  • PyTorch 2.1.2
  • Datasets 2.18.0
  • Tokenizers 0.15.2
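To check whether a local environment matches these versions, the standard-library `importlib.metadata` module can read installed package versions. A minimal sketch, assuming the PyPI distribution names `transformers`, `torch`, `datasets` and `tokenizers`; PyTorch wheels may carry a local build tag such as `+cu121`, which the sketch strips before comparing. The helper `compare_versions` is hypothetical.

```python
from importlib import metadata

# Versions listed under "Framework versions"; PyTorch's distribution name is "torch".
PINNED = {
    "transformers": "4.39.3",
    "torch": "2.1.2",
    "datasets": "2.18.0",
    "tokenizers": "0.15.2",
}


def compare_versions(pins=PINNED):
    """Map each distribution to (pinned, installed or None, matches)."""
    report = {}
    for dist, pinned in pins.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            installed = None
        # Drop local build tags such as "+cu121" before comparing.
        base = installed.split("+", 1)[0] if installed else None
        report[dist] = (pinned, installed, base == pinned)
    return report


for dist, (pinned, installed, ok) in compare_versions().items():
    status = "ok" if ok else "differs"
    print(f"{dist:<13} pinned {pinned:<8} installed {installed or 'missing':<14} {status}")
```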
Model size

  • Parameters: 85.8M (Safetensors)
  • Tensor type: F32
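The 85.8M figure can be checked by hand. The sketch below assumes the standard ViT-Base dimensions (hidden size 768, 12 encoder layers, feed-forward size 3072), the patch size and resolution encoded in the base checkpoint name (16 and 224), a Transformers-style `ViTForImageClassification` layout (no pooler, one linear head on the [CLS] token), and the 10 classes seen in the confusion matrix. None of these dimensions are stated in this card itself.

```python
HIDDEN = 768      # assumption: ViT-Base hidden size
LAYERS = 12       # assumption: ViT-Base encoder depth
MLP = 3072        # assumption: ViT-Base feed-forward size
PATCH = 16        # "patch16" in the base checkpoint name
IMAGE = 224       # "224" in the base checkpoint name
CHANNELS = 3      # RGB input
NUM_LABELS = 10   # classes 0-9 in the confusion matrix


def linear(n_in, n_out):
    """Parameters of a dense layer with bias."""
    return n_in * n_out + n_out


def layer_norm(dim):
    """LayerNorm has one scale and one shift per feature."""
    return 2 * dim


num_patches = (IMAGE // PATCH) ** 2   # 14 * 14 = 196
seq_len = num_patches + 1              # plus the [CLS] token

embeddings = (
    PATCH * PATCH * CHANNELS * HIDDEN + HIDDEN   # patch projection (conv weight + bias)
    + HIDDEN                                     # [CLS] token
    + seq_len * HIDDEN                           # learned position embeddings
)
per_layer = (
    3 * linear(HIDDEN, HIDDEN)                   # query, key, value projections
    + linear(HIDDEN, HIDDEN)                     # attention output projection
    + linear(HIDDEN, MLP) + linear(MLP, HIDDEN)  # feed-forward block
    + 2 * layer_norm(HIDDEN)                     # pre-attention and pre-MLP LayerNorms
)
backbone = embeddings + LAYERS * per_layer + layer_norm(HIDDEN)  # plus final LayerNorm
head = linear(HIDDEN, NUM_LABELS)                # classifier on the [CLS] output
total = backbone + head

print(num_patches, seq_len, total, f"{total / 1e6:.1f}M")
```

Almost all of the count sits in the pretrained backbone (85,798,656 parameters under these assumptions); the 10-way classification head adds only 7,690.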