---
library_name: sklearn
tags:
- sklearn
- skops
- text-classification
---

# Model description

This model is an sklearn Pipeline that chains a DistilBERT tokenizer with a skorch neural net classifier wrapping a DistilBERT model, trained on the 20 newsgroups dataset.

## Intended uses & limitations

This model was trained for a tutorial and is not ready to be used in production.

## Training Procedure

### Hyperparameters

The model was trained with the hyperparameters below.
<details>
<summary>Click to expand</summary>

| Hyperparameter | Value |
|---|---|
| memory | |
| steps | [('tokenizer', HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')), ('net', NeuralNetClassifier [initialized]; the full module repr is given in the net row below)] |
| verbose | False |
| tokenizer | HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased') |
| net | NeuralNetClassifier [initialized](module_=BertModule( (bert): DistilBertForSequenceClassification( (distilbert): DistilBertModel( (embeddings): Embeddings( (word_embeddings): Embedding(30522, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (transformer): Transformer( (layer): ModuleList( (0-5): 6 x TransformerBlock( (attention): MultiHeadSelfAttention( (dropout): Dropout(p=0.1, inplace=False) (q_lin): Linear(in_features=768, out_features=768, bias=True) (k_lin): Linear(in_features=768, out_features=768, bias=True) (v_lin): Linear(in_features=768, out_features=768, bias=True) (out_lin): Linear(in_features=768, out_features=768, bias=True) ) (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (ffn): FFN( (dropout): Dropout(p=0.1, inplace=False) (lin1): Linear(in_features=768, out_features=3072, bias=True) (lin2): Linear(in_features=3072, out_features=768, bias=True) (activation): GELUActivation() ) (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) ) ) ) ) (pre_classifier): Linear(in_features=768, out_features=768, bias=True) (classifier): Linear(in_features=768, out_features=20, bias=True) (dropout): Dropout(p=0.2, inplace=False) )) |
| tokenizer__max_length | 256 |
| tokenizer__return_attention_mask | True |
| tokenizer__return_length | False |
| tokenizer__return_tensors | pt |
| tokenizer__return_token_type_ids | False |
| tokenizer__tokenizer | distilbert-base-uncased |
| tokenizer__train | False |
| tokenizer__verbose | 0 |
| tokenizer__vocab_size | |
| net__module | |
| net__criterion | |
| net__optimizer | |
| net__lr | 5e-05 |
| net__max_epochs | 3 |
| net__batch_size | 8 |
| net__iterator_train | |
| net__iterator_valid | |
| net__dataset | |
| net__train_split | |
| net__callbacks | [LRScheduler, ProgressBar] |
| net__predict_nonlinearity | auto |
| net__warm_start | False |
| net__verbose | 1 |
| net__device | cuda |
| net___params_to_validate | {'module__num_labels', 'module__name', 'iterator_train__shuffle'} |
| net__module__name | distilbert-base-uncased |
| net__module__num_labels | 20 |
| net__iterator_train__shuffle | True |
| net__classes | |
| net__callbacks__epoch_timer | |
| net__callbacks__train_loss | |
| net__callbacks__train_loss__name | train_loss |
| net__callbacks__train_loss__lower_is_better | True |
| net__callbacks__train_loss__on_train | True |
| net__callbacks__valid_loss | |
| net__callbacks__valid_loss__name | valid_loss |
| net__callbacks__valid_loss__lower_is_better | True |
| net__callbacks__valid_loss__on_train | False |
| net__callbacks__valid_acc | |
| net__callbacks__valid_acc__scoring | accuracy |
| net__callbacks__valid_acc__lower_is_better | False |
| net__callbacks__valid_acc__on_train | False |
| net__callbacks__valid_acc__name | valid_acc |
| net__callbacks__valid_acc__target_extractor | |
| net__callbacks__valid_acc__use_caching | True |
| net__callbacks__LRScheduler | |
| net__callbacks__LRScheduler__policy | |
| net__callbacks__LRScheduler__monitor | train_loss |
| net__callbacks__LRScheduler__event_name | event_lr |
| net__callbacks__LRScheduler__step_every | batch |
| net__callbacks__LRScheduler__lr_lambda | |
| net__callbacks__ProgressBar | |
| net__callbacks__ProgressBar__batches_per_epoch | auto |
| net__callbacks__ProgressBar__detect_notebook | True |
| net__callbacks__ProgressBar__postfix_keys | ['train_loss', 'valid_loss'] |
| net__callbacks__print_log | |
| net__callbacks__print_log__keys_ignored | |
| net__callbacks__print_log__sink | |
| net__callbacks__print_log__tablefmt | simple |
| net__callbacks__print_log__floatfmt | .4f |
| net__callbacks__print_log__stralign | right |

</details>
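To make the table concrete, here is a minimal sketch of how a pipeline with this configuration could be assembled with skorch. `BertModule` is the tutorial's thin wrapper around a Hugging Face sequence-classification model; its exact definition, the `CrossEntropyLoss` criterion, and the constant `lr_lambda` placeholder are assumptions, since the card does not record them.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from transformers import AutoModelForSequenceClassification


class BertModule(nn.Module):
    """Thin wrapper exposing a Hugging Face classifier to skorch (assumed definition)."""

    def __init__(self, name, num_labels):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            name, num_labels=num_labels
        )

    def forward(self, **kwargs):
        # Return raw logits so the criterion can be applied directly.
        return self.bert(**kwargs).logits


pipeline = Pipeline([
    ("tokenizer", HuggingfacePretrainedTokenizer(
        "distilbert-base-uncased", max_length=256)),
    ("net", NeuralNetClassifier(
        BertModule,
        module__name="distilbert-base-uncased",
        module__num_labels=20,
        criterion=nn.CrossEntropyLoss,  # assumption: not recorded in the card
        lr=5e-5,
        max_epochs=3,
        batch_size=8,
        iterator_train__shuffle=True,
        device="cuda" if torch.cuda.is_available() else "cpu",
        callbacks=[
            # The card records a per-batch LambdaLR schedule but not the
            # lambda itself; a constant schedule stands in here.
            LRScheduler(LambdaLR, lr_lambda=lambda step: 1.0, step_every="batch"),
            ProgressBar(),
        ],
    )),
])
```

Fitting then follows the usual sklearn convention, e.g. `pipeline.fit(X_train, y_train)` with raw text strings as `X_train`.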
### Model Plot

The model plot is below.
```
Pipeline(steps=[('tokenizer',HuggingfacePretrainedTokenizer(tokenizer='distilbert-base-uncased')),('net',<class 'skorch.classifier.NeuralNetClassifier'>[initialized](module_=BertModule((bert): DistilBertForSequenceClassification((distilbert): DistilBertModel((embeddings): Embeddings((word_embeddings): Embedding(30522, 768, padding_idx=0)(position_embeddin...(lin1): Linear(in_features=768, out_features=3072, bias=True)(lin2): Linear(in_features=3072, out_features=768, bias=True)(activation): GELUActivation())(output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)))))(pre_classifier): Linear(in_features=768, out_features=768, bias=True)(classifier): Linear(in_features=768, out_features=20, bias=True)(dropout): Dropout(p=0.2, inplace=False))),))])
```
## Evaluation Results

The metrics below summarize the evaluation results; a per-class breakdown is given in the Classification Report section.

| Metric   | Value   |
|----------|---------|
| accuracy | 0.90562 |
| f1 score | 0.90562 |

# How to Get Started with the Model

Use the code below to get started with the model.
<details>
<summary>Click to expand</summary>

```python
[More Information Needed]
```

</details>
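Since the card's usage snippet is not filled in, here is a hedged sketch of loading a skops-serialized pipeline from the Hub. The repository id and file name are placeholders, not taken from this card, and `trusted` should only be given types you have actually reviewed.

```python
# A minimal sketch, assuming the pipeline was saved with skops.
from huggingface_hub import hf_hub_download
from skops.io import get_untrusted_types, load

# Placeholder repo id and file name; adjust to the actual repository layout.
path = hf_hub_download(repo_id="<user>/<this-repo>", filename="model.skops")

# Inspect which types the file wants to deserialize before trusting them.
unknown = get_untrusted_types(file=path)
print(unknown)

model = load(path, trusted=unknown)  # only pass types you have reviewed
print(model.predict(["My GPU overheats when I play games."]))
```

Note that the net was trained with device='cuda', so inference is simplest on a CUDA-enabled machine.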
# Additional Content

## Confusion matrix

![Confusion matrix](confusion_matrix.png)

## Classification Report
<details>
<summary>Click to expand</summary>

| index | precision | recall | f1-score | support |
|---|---|---|---|---|
| alt.atheism | 0.927273 | 0.85 | 0.886957 | 120 |
| comp.graphics | 0.85906 | 0.876712 | 0.867797 | 146 |
| comp.os.ms-windows.misc | 0.893617 | 0.851351 | 0.871972 | 148 |
| comp.sys.ibm.pc.hardware | 0.666667 | 0.837838 | 0.742515 | 148 |
| comp.sys.mac.hardware | 0.901515 | 0.826389 | 0.862319 | 144 |
| comp.windows.x | 0.923077 | 0.891892 | 0.907216 | 148 |
| misc.forsale | 0.875862 | 0.869863 | 0.872852 | 146 |
| rec.autos | 0.893082 | 0.95302 | 0.922078 | 149 |
| rec.motorcycles | 0.937931 | 0.906667 | 0.922034 | 150 |
| rec.sport.baseball | 0.954248 | 0.979866 | 0.966887 | 149 |
| rec.sport.hockey | 0.979866 | 0.973333 | 0.976589 | 150 |
| sci.crypt | 0.993103 | 0.966443 | 0.979592 | 149 |
| sci.electronics | 0.869565 | 0.810811 | 0.839161 | 148 |
| sci.med | 0.973154 | 0.973154 | 0.973154 | 149 |
| sci.space | 0.973333 | 0.986486 | 0.979866 | 148 |
| soc.religion.christian | 0.927152 | 0.933333 | 0.930233 | 150 |
| talk.politics.guns | 0.961538 | 0.919118 | 0.93985 | 136 |
| talk.politics.mideast | 0.978571 | 0.971631 | 0.975089 | 141 |
| talk.politics.misc | 0.925234 | 0.853448 | 0.887892 | 116 |
| talk.religion.misc | 0.728972 | 0.829787 | 0.776119 | 94 |
| macro avg | 0.907141 | 0.903057 | 0.904009 | 2829 |
| weighted avg | 0.909947 | 0.90562 | 0.906742 | 2829 |

</details>
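A report and confusion matrix of this shape can be reproduced with standard sklearn utilities. A minimal sketch follows, assuming `pipeline` is the fitted pipeline from the earlier sketch; the card does not record the exact evaluation split (the support column sums to 2829, a subset of the full 20 newsgroups test set), so the split below is illustrative.

```python
# A sketch of how the report and confusion matrix above could be produced;
# `pipeline` is assumed to be the fitted Pipeline from the earlier sketch.
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import ConfusionMatrixDisplay, classification_report

test = fetch_20newsgroups(subset="test")
y_pred = pipeline.predict(test.data)

# Per-class precision/recall/f1 plus macro and weighted averages.
print(classification_report(test.target, y_pred, target_names=test.target_names))

# Confusion matrix rendered with class names on both axes.
fig, ax = plt.subplots(figsize=(12, 12))
ConfusionMatrixDisplay.from_predictions(
    test.target, y_pred, display_labels=test.target_names,
    xticks_rotation="vertical", ax=ax,
)
fig.savefig("confusion_matrix.png")
```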