DistilBERT for MIC Dataset
This repository contains the code and a pre-trained DistilBERT model for the Multi-Industry Classification (MIC) dataset. The model was fine-tuned for 4 epochs on a combination of the AG News, CNN News, and 20 Newsgroups (relabeled) datasets.
Model description
DistilBERT is a smaller, faster version of BERT, the state-of-the-art language model developed by Google; DistilBERT itself was distilled from BERT by Hugging Face. It keeps the same general architecture as BERT but with fewer layers and fewer parameters, which makes it faster and easier to deploy in production environments.
The pre-trained DistilBERT model was fine-tuned on the MIC dataset, which consists of news articles from three different industries: finance, technology, and sports. The task is to classify each article into one of these three categories.
- Developed by: CHERGUELAINE Ayoub
- Shared by: Hugging Face
- Model type: Language model
- Language(s) (NLP): en
- Finetuned from model: distilbert-base-uncased
Requirements
- Python 3.x
- PyTorch 1.8 or later
- Transformers 4.0 or later
- Pandas
- NumPy
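Assuming a standard pip environment, the dependencies above can be installed in one step (the version pins below mirror the requirements list and are illustrative):

```shell
# Install PyTorch, Transformers, and the data-handling dependencies
pip install "torch>=1.8" "transformers>=4.0" pandas numpy
```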
Evaluation
The model was evaluated on two different datasets: AG News and CNN News. Here are the evaluation results:
AG News
- Loss: 0.1668
- Accuracy: 0.9482
- Precision: 0.9484
- Recall: 0.9482
- F1 score: 0.9482
CNN News
- Loss: 0.1199
- Accuracy: 0.9633
- Precision: 0.9638
- Recall: 0.9633
- F1 score: 0.9635
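For reference, metrics of this form can be computed with scikit-learn. The sketch below uses toy labels in place of real evaluation outputs, and the `weighted` averaging mode is an assumption on my part; the card does not state which averaging was used:

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Toy ground-truth and predicted class indices standing in for a real eval run
y_true = [0, 1, 2, 1, 0, 2]
y_pred = [0, 1, 2, 0, 0, 2]

# Accuracy plus per-class precision/recall/F1 averaged by class support
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted"
)
print(accuracy, precision, recall, f1)
```

Note that with weighted averaging, recall equals accuracy, which matches the pattern visible in the numbers above.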
Usage
You can use the pre-trained model to make predictions on new news articles. Here's an example:
```python
import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("AyoubChLin/distilbert_ag_cnn")
model = DistilBertForSequenceClassification.from_pretrained("AyoubChLin/distilbert_ag_cnn")

# Replace input_text with the actual news article
input_text = "The stock market had a strong performance today, driven by strong earnings reports from technology companies."

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt", truncation=True)

# Make a prediction (no gradients needed at inference time)
with torch.no_grad():
    outputs = model(**inputs)
predicted_label = torch.argmax(outputs.logits, dim=1)

# Print the predicted class index
print(predicted_label)
```
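To turn the raw logits into probabilities and a human-readable label, apply a softmax and index into the model's label mapping. The sketch below uses dummy logits and a placeholder label mapping; the real mapping is exposed by the fine-tuned model as `model.config.id2label`:

```python
import torch

# Dummy logits standing in for outputs.logits from the model above
logits = torch.tensor([[2.1, 0.3, -1.0]])

# Softmax converts logits to a probability distribution over classes
probs = torch.softmax(logits, dim=1)
predicted = torch.argmax(probs, dim=1).item()

# Placeholder mapping -- the real one comes from model.config.id2label
id2label = {0: "finance", 1: "technology", 2: "sports"}
print(id2label[predicted], probs[0, predicted].item())
```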
Credits
- The MIC dataset was collected and annotated by CHERGUELAINE Ayoub