DistilBERT for MIC Dataset

This repository contains the code and a pre-trained DistilBERT model for the Multi-Industry Classification (MIC) dataset. The model was fine-tuned for 4 epochs on a combination of the AG News, CNN News, and 20 Newsgroups (relabeled) datasets.

Model description

DistilBERT is a smaller, faster distilled version of BERT, a state-of-the-art language model developed by Google. DistilBERT follows the same general architecture as BERT but has fewer layers and fewer parameters, which makes it faster and easier to deploy in production environments.

The pre-trained DistilBERT model was fine-tuned on the MIC dataset, which consists of news articles from three different industries: finance, technology, and sports. The task is to classify each article into one of these three categories.

Requirements

  • Python 3.x
  • PyTorch 1.8 or later
  • Transformers 4.0 or later
  • Pandas
  • NumPy
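
Assuming a standard pip-based environment, the dependencies above can be installed with (package versions here are a sketch matching the minimums listed, not a pinned, tested set):

```shell
pip install "torch>=1.8" "transformers>=4.0" pandas numpy
```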

Evaluation

The model was evaluated on two different datasets: AG News and CNN News. Here are the evaluation results:

AG News

  • Loss: 0.1668
  • Accuracy: 0.9482
  • Precision: 0.9484
  • Recall: 0.9482
  • F1 score: 0.9482

CNN News

  • Loss: 0.1199
  • Accuracy: 0.9633
  • Precision: 0.9638
  • Recall: 0.9633
  • F1 score: 0.9635
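
Accuracy equals recall in both tables, which suggests the precision, recall, and F1 scores are support-weighted averages (weighted recall always equals accuracy in single-label classification). A minimal sketch of how such metrics can be computed with scikit-learn, using hypothetical toy labels rather than the actual evaluation data:

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Hypothetical gold and predicted class ids for a 3-class task
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 1, 2, 1, 1, 0]

accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted"
)

# With weighted averaging, recall equals accuracy
print(accuracy, precision, recall, f1)
```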

Usage

You can use the pre-trained model to make predictions on new news articles. Here's an example:

```python
import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("AyoubChLin/distilbert_ag_cnn")
model = DistilBertForSequenceClassification.from_pretrained("AyoubChLin/distilbert_ag_cnn")

# Replace input_text with the actual news article
input_text = "The stock market had a strong performance today, driven by strong earnings reports from technology companies."

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt")

# Make a prediction (no gradients needed at inference time)
with torch.no_grad():
    outputs = model(**inputs)
predicted_label = torch.argmax(outputs.logits, dim=1).item()

# Print the predicted label id
print(predicted_label)
```
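
To turn raw logits into class probabilities, apply a softmax over the class dimension. A self-contained sketch with made-up logits (real values would come from `model(**inputs).logits`, and the meaning of each label id depends on the model's `id2label` config):

```python
import torch

# Made-up logits for a single example from a 3-class model
logits = torch.tensor([[2.0, 0.5, -1.0]])

# Softmax converts logits into probabilities that sum to 1
probs = torch.softmax(logits, dim=1)
predicted_label = int(torch.argmax(probs, dim=1))

print(predicted_label, probs.tolist())
```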
