
Model description

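This is a scikit-learn Pipeline that chains a Hugging Face tokenizer (distilbert-base-uncased) with a skorch NeuralNetClassifier wrapping a DistilBERT sequence-classification model. It was trained on the 20 Newsgroups dataset.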

Intended uses & limitations

This model was trained for a tutorial and is not ready for production use.

Training Procedure

Hyperparameters

The model was trained with the hyperparameters below.

| Hyperparameter | Value |
|---|---|
| memory | |
| steps | [('tokenizer', HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')), ('net', <class 'skorch.classifier.NeuralNetClassifier'>[initialized](module_=BertModule(...),))] |
| verbose | False |
| tokenizer | HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased') |
| net | <class 'skorch.classifier.NeuralNetClassifier'>[initialized](module_=BertModule(...),) |
| tokenizer__max_length | 256 |
| tokenizer__return_attention_mask | True |
| tokenizer__return_length | False |
| tokenizer__return_tensors | pt |
| tokenizer__return_token_type_ids | False |
| tokenizer__tokenizer | distilbert-base-uncased |
| tokenizer__train | False |
| tokenizer__verbose | 0 |
| tokenizer__vocab_size | |
| net__module | <class '__main__.BertModule'> |
| net__criterion | <class 'torch.nn.modules.loss.CrossEntropyLoss'> |
| net__optimizer | <class 'torch.optim.adamw.AdamW'> |
| net__lr | 5e-05 |
| net__max_epochs | 3 |
| net__batch_size | 8 |
| net__iterator_train | <class 'torch.utils.data.dataloader.DataLoader'> |
| net__iterator_valid | <class 'torch.utils.data.dataloader.DataLoader'> |
| net__dataset | <class 'skorch.dataset.Dataset'> |
| net__train_split | <skorch.dataset.ValidSplit object at 0x7f9945e18c90> |
| net__callbacks | [<skorch.callbacks.lr_scheduler.LRScheduler object at 0x7f9945da85d0>, <skorch.callbacks.logging.ProgressBar object at 0x7f9945da8250>] |
| net__predict_nonlinearity | auto |
| net__warm_start | False |
| net__verbose | 1 |
| net__device | cuda |
| net___params_to_validate | {'module__num_labels', 'module__name', 'iterator_train__shuffle'} |
| net__module__name | distilbert-base-uncased |
| net__module__num_labels | 20 |
| net__iterator_train__shuffle | True |
| net__classes | |
| net__callbacks__epoch_timer | <skorch.callbacks.logging.EpochTimer object at 0x7f993cb300d0> |
| net__callbacks__train_loss | <skorch.callbacks.scoring.PassthroughScoring object at 0x7f993cb306d0> |
| net__callbacks__train_loss__name | train_loss |
| net__callbacks__train_loss__lower_is_better | True |
| net__callbacks__train_loss__on_train | True |
| net__callbacks__valid_loss | <skorch.callbacks.scoring.PassthroughScoring object at 0x7f993cb30ed0> |
| net__callbacks__valid_loss__name | valid_loss |
| net__callbacks__valid_loss__lower_is_better | True |
| net__callbacks__valid_loss__on_train | False |
| net__callbacks__valid_acc | <skorch.callbacks.scoring.EpochScoring object at 0x7f993cb30410> |
| net__callbacks__valid_acc__scoring | accuracy |
| net__callbacks__valid_acc__lower_is_better | False |
| net__callbacks__valid_acc__on_train | False |
| net__callbacks__valid_acc__name | valid_acc |
| net__callbacks__valid_acc__target_extractor | <function to_numpy at 0x7f9945e46a70> |
| net__callbacks__valid_acc__use_caching | True |
| net__callbacks__LRScheduler | <skorch.callbacks.lr_scheduler.LRScheduler object at 0x7f9945da85d0> |
| net__callbacks__LRScheduler__policy | <class 'torch.optim.lr_scheduler.LambdaLR'> |
| net__callbacks__LRScheduler__monitor | train_loss |
| net__callbacks__LRScheduler__event_name | event_lr |
| net__callbacks__LRScheduler__step_every | batch |
| net__callbacks__LRScheduler__lr_lambda | <function lr_schedule at 0x7f9945d9c440> |
| net__callbacks__ProgressBar | <skorch.callbacks.logging.ProgressBar object at 0x7f9945da8250> |
| net__callbacks__ProgressBar__batches_per_epoch | auto |
| net__callbacks__ProgressBar__detect_notebook | True |
| net__callbacks__ProgressBar__postfix_keys | ['train_loss', 'valid_loss'] |
| net__callbacks__print_log | <skorch.callbacks.logging.PrintLog object at 0x7f993cb30dd0> |
| net__callbacks__print_log__keys_ignored | |
| net__callbacks__print_log__sink | |
| net__callbacks__print_log__tablefmt | simple |
| net__callbacks__print_log__floatfmt | .4f |
| net__callbacks__print_log__stralign | right |

In the steps and net rows, BertModule(...) abbreviates the same initialized module, whose full repr is:

    <class 'skorch.classifier.NeuralNetClassifier'>[initialized](
      module_=BertModule(
        (bert): DistilBertForSequenceClassification(
          (distilbert): DistilBertModel(
            (embeddings): Embeddings(
              (word_embeddings): Embedding(30522, 768, padding_idx=0)
              (position_embeddings): Embedding(512, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (transformer): Transformer(
              (layer): ModuleList(
                (0-5): 6 x TransformerBlock(
                  (attention): MultiHeadSelfAttention(
                    (dropout): Dropout(p=0.1, inplace=False)
                    (q_lin): Linear(in_features=768, out_features=768, bias=True)
                    (k_lin): Linear(in_features=768, out_features=768, bias=True)
                    (v_lin): Linear(in_features=768, out_features=768, bias=True)
                    (out_lin): Linear(in_features=768, out_features=768, bias=True)
                  )
                  (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (ffn): FFN(
                    (dropout): Dropout(p=0.1, inplace=False)
                    (lin1): Linear(in_features=768, out_features=3072, bias=True)
                    (lin2): Linear(in_features=3072, out_features=768, bias=True)
                    (activation): GELUActivation()
                  )
                  (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                )
              )
            )
          )
          (pre_classifier): Linear(in_features=768, out_features=768, bias=True)
          (classifier): Linear(in_features=768, out_features=20, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      ),
    )
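The table records the learning-rate schedule only indirectly: an LRScheduler callback using the LambdaLR policy, stepped every batch, with an lr_lambda named lr_schedule whose body is not shown. The sketch below assumes a linear decay to zero over all training batches, a common choice for fine-tuning but not confirmed by the card. make_linear_decay and total_batches are hypothetical helpers.

```python
import math


def make_linear_decay(num_training_steps):
    """Hypothetical stand-in for the card's lr_schedule.

    Returns a function that maps the scheduler's step counter to a multiplier
    on the optimizer's base learning rate (net__lr = 5e-05 in the table).
    """

    def lr_schedule(current_step):
        # Linear decay from 1.0 at step 0 to 0.0 at num_training_steps,
        # clamped so the multiplier never becomes negative.
        remaining = num_training_steps - current_step
        return max(0.0, remaining / max(1, num_training_steps))

    return lr_schedule


def total_batches(n_train, batch_size=8, max_epochs=3):
    # One scheduler step per batch (step_every='batch'). The last batch of each
    # epoch may be partial, hence the ceiling.
    return max_epochs * math.ceil(n_train / batch_size)


steps = total_batches(n_train=1000)  # hypothetical training-split size
schedule = make_linear_decay(steps)
for step in (0, steps // 2, steps):
    print(step, 5e-5 * schedule(step))
```

In skorch this would be attached as LRScheduler(policy=LambdaLR, lr_lambda=schedule, step_every='batch'), which matches the policy and step_every rows above. n_train should be the size of the training split left after ValidSplit has held out the validation data.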

Model Plot

The model plot is below.

Pipeline(steps=[('tokenizer',HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')),('net',<class 'skorch.classifier.NeuralNetClassifier'>[initialized](module_=BertModule((bert): DistilBertForSequenceClassification((distilbert): DistilBertModel((embeddings): Embeddings((word_embeddings): Embedding(30522, 768, padding_idx=0)(position_embeddin...(lin1): Linear(in_features=768, out_features=3072, bias=True)(lin2): Linear(in_features=3072, out_features=768, bias=True)(activation): GELUActivation())(output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)))))(pre_classifier): Linear(in_features=768, out_features=768, bias=True)(classifier): Linear(in_features=768, out_features=20, bias=True)(dropout): Dropout(p=0.2, inplace=False))),))])


Evaluation Results

This section reports the evaluation results. A per-class breakdown over 2,829 evaluation documents is given in the classification report under Additional Content.

| Metric | Value |
|---|---|
| accuracy | 0.90562 |
| f1 score | 0.90562 |
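The f1 score above matches the accuracy to every printed digit. That is what a micro-averaged F1 produces for single-label multiclass data. When counts are pooled over all classes, every wrong prediction is one false positive (for the predicted class) and one false negative (for the true class). Micro precision and micro recall therefore both equal accuracy. The card does not state which averaging mode it used, so the micro-average reading is an inference. It is supported by the macro (0.904009) and weighted (0.906742) F1 values in the classification report, both of which differ from 0.90562. A small pure-Python check on toy labels:

```python
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def micro_f1(y_true, y_pred):
    # Pool true positives, false positives and false negatives over all classes.
    labels = set(y_true) | set(y_pred)
    tp = fp = fn = 0
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        tp += sum(t == label and p == label for t, p in pairs)
        fp += sum(t != label and p == label for t, p in pairs)
        fn += sum(t == label and p != label for t, p in pairs)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Toy labels for illustration only; not the card's evaluation data.
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
print(accuracy(y_true, y_pred), micro_f1(y_true, y_pred))
```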

How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]
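No official loading code is given. The sketch below assumes two things the card does not confirm. First, it assumes the fitted pipeline was serialized with pickle; skops.io is another possibility. Second, it assumes predictions are integer indices in the alphabetical class order of the classification report. load_pipeline, predict_topics and the file path are hypothetical.

```python
import pickle

# Class order as listed in the classification report (alphabetical); assumed
# to match the integer targets the net was trained on.
TARGET_NAMES = [
    "alt.atheism", "comp.graphics", "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware", "comp.sys.mac.hardware", "comp.windows.x",
    "misc.forsale", "rec.autos", "rec.motorcycles", "rec.sport.baseball",
    "rec.sport.hockey", "sci.crypt", "sci.electronics", "sci.med",
    "sci.space", "soc.religion.christian", "talk.politics.guns",
    "talk.politics.mideast", "talk.politics.misc", "talk.religion.misc",
]


def load_pipeline(path):
    """Hypothetical loader, assuming the fitted Pipeline was pickled to `path`.

    The hyperparameters list net__module as __main__.BertModule, so a class of
    that name must be defined in the running script before unpickling;
    otherwise pickle.load raises an AttributeError.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def predict_topics(pipe, texts):
    # The pipeline takes raw strings: the tokenizer step encodes them and the
    # net returns integer class indices, mapped here to newsgroup names.
    return [TARGET_NAMES[int(i)] for i in pipe.predict(texts)]
```

For example, predict_topics(load_pipeline("model.pkl"), ["The new GPU drivers crash on boot."]) would return one newsgroup name per input string. "model.pkl" is a placeholder path.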

Additional Content

Confusion matrix

[Figure: confusion matrix]
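The figure itself is not reproduced here. The card does not say how the matrix was produced; sklearn.metrics.confusion_matrix is the usual tool. The pure-Python sketch below uses the same convention as that function: rows are true classes and columns are predicted classes. It shows how such a matrix relates to the classification report that follows. Per-class recall is the diagonal entry divided by its row sum (the support), and precision is the diagonal entry divided by its column sum.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    # cm[t][p] counts documents of true class t predicted as class p.
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def recall_precision(cm):
    n = len(cm)
    recall, precision = [], []
    for i in range(n):
        row = sum(cm[i])                     # support of class i
        col = sum(cm[r][i] for r in range(n))  # times class i was predicted
        recall.append(cm[i][i] / row if row else 0.0)
        precision.append(cm[i][i] / col if col else 0.0)
    return recall, precision


# Toy labels for illustration only; not the card's evaluation data.
cm = confusion_matrix([0, 0, 1, 1, 2, 2, 2], [0, 1, 1, 1, 2, 0, 2], n_classes=3)
print(cm)
print(recall_precision(cm))
```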

Classification Report

| class | precision | recall | f1-score | support |
|---|---|---|---|---|
| alt.atheism | 0.927273 | 0.85 | 0.886957 | 120 |
| comp.graphics | 0.85906 | 0.876712 | 0.867797 | 146 |
| comp.os.ms-windows.misc | 0.893617 | 0.851351 | 0.871972 | 148 |
| comp.sys.ibm.pc.hardware | 0.666667 | 0.837838 | 0.742515 | 148 |
| comp.sys.mac.hardware | 0.901515 | 0.826389 | 0.862319 | 144 |
| comp.windows.x | 0.923077 | 0.891892 | 0.907216 | 148 |
| misc.forsale | 0.875862 | 0.869863 | 0.872852 | 146 |
| rec.autos | 0.893082 | 0.95302 | 0.922078 | 149 |
| rec.motorcycles | 0.937931 | 0.906667 | 0.922034 | 150 |
| rec.sport.baseball | 0.954248 | 0.979866 | 0.966887 | 149 |
| rec.sport.hockey | 0.979866 | 0.973333 | 0.976589 | 150 |
| sci.crypt | 0.993103 | 0.966443 | 0.979592 | 149 |
| sci.electronics | 0.869565 | 0.810811 | 0.839161 | 148 |
| sci.med | 0.973154 | 0.973154 | 0.973154 | 149 |
| sci.space | 0.973333 | 0.986486 | 0.979866 | 148 |
| soc.religion.christian | 0.927152 | 0.933333 | 0.930233 | 150 |
| talk.politics.guns | 0.961538 | 0.919118 | 0.93985 | 136 |
| talk.politics.mideast | 0.978571 | 0.971631 | 0.975089 | 141 |
| talk.politics.misc | 0.925234 | 0.853448 | 0.887892 | 116 |
| talk.religion.misc | 0.728972 | 0.829787 | 0.776119 | 94 |
| macro avg | 0.907141 | 0.903057 | 0.904009 | 2829 |
| weighted avg | 0.909947 | 0.90562 | 0.906742 | 2829 |
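The summary rows can be checked against the per-class rows. The macro avg is the unweighted mean over the 20 classes, and the weighted avg weights each class by its support. Weighted-average recall equals overall accuracy, because recall x support is the number of correctly classified documents in each class. The sketch below recomputes the card's macro F1, weighted F1 and accuracy from the table values.

```python
# Per-class (f1-score, recall, support), copied from the classification report.
REPORT = {
    "alt.atheism": (0.886957, 0.85, 120),
    "comp.graphics": (0.867797, 0.876712, 146),
    "comp.os.ms-windows.misc": (0.871972, 0.851351, 148),
    "comp.sys.ibm.pc.hardware": (0.742515, 0.837838, 148),
    "comp.sys.mac.hardware": (0.862319, 0.826389, 144),
    "comp.windows.x": (0.907216, 0.891892, 148),
    "misc.forsale": (0.872852, 0.869863, 146),
    "rec.autos": (0.922078, 0.95302, 149),
    "rec.motorcycles": (0.922034, 0.906667, 150),
    "rec.sport.baseball": (0.966887, 0.979866, 149),
    "rec.sport.hockey": (0.976589, 0.973333, 150),
    "sci.crypt": (0.979592, 0.966443, 149),
    "sci.electronics": (0.839161, 0.810811, 148),
    "sci.med": (0.973154, 0.973154, 149),
    "sci.space": (0.979866, 0.986486, 148),
    "soc.religion.christian": (0.930233, 0.933333, 150),
    "talk.politics.guns": (0.93985, 0.919118, 136),
    "talk.politics.mideast": (0.975089, 0.971631, 141),
    "talk.politics.misc": (0.887892, 0.853448, 116),
    "talk.religion.misc": (0.776119, 0.829787, 94),
}


def summarize(report):
    n = sum(s for _, _, s in report.values())
    # Macro average: plain mean of the per-class F1 scores.
    macro_f1 = sum(f for f, _, _ in report.values()) / len(report)
    # Weighted average: each class's F1 weighted by its support.
    weighted_f1 = sum(f * s for f, _, s in report.values()) / n
    # recall * support recovers the integer count of correct predictions per class.
    correct = sum(round(r * s) for _, r, s in report.values())
    return macro_f1, weighted_f1, correct / n


print(summarize(REPORT))
```

The reconstructed counts add up to 2,562 correct predictions out of 2,829, which is the 0.90562 accuracy reported under Evaluation Results.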