---
license: apache-2.0
datasets:
- ag_news
- AyoubChLin/CNN_News_Articles_2011-2022
language:
- en
metrics:
- accuracy
pipeline_tag: text-classification
---

# DistilBERT for MIX Dataset

This repository contains the code and a pre-trained DistilBERT model for the Multi-Industry Classification (MIX) dataset. The model was fine-tuned for 4 epochs on a combination of the AG News, CNN News, and 20 Newsgroups (relabeled) datasets.

## Model description

DistilBERT is a smaller, faster version of BERT, a state-of-the-art language model developed by Google. DistilBERT keeps BERT's overall architecture but uses fewer layers and fewer parameters, which makes it faster and easier to deploy in production environments.

The pre-trained DistilBERT model was fine-tuned on the MIX dataset, which consists of news articles from three different industries: finance, technology, and sports. The task is to classify each article into one of these three categories.

- **Developed by:** [CHERGUELAINE Ayoub](https://www.linkedin.com/in/ayoub-cherguelaine/)
- **Shared by:** Hugging Face
- **Model type:** Language model
- **Language(s) (NLP):** en
- **Finetuned from model:** [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased)

## Requirements

- Python 3.x
- PyTorch 1.8 or later
- Transformers 4.0 or later
- Pandas
- NumPy

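These dependencies can be installed with pip; the version pins below mirror the minimums listed above and are illustrative rather than tested:

```shell
pip install "torch>=1.8" "transformers>=4.0" pandas numpy
```
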
## Evaluation

The model was evaluated on two datasets, AG News and CNN News, with the following results:

### AG News

- Loss: 0.16675
- Accuracy: 0.94815789
- Precision: 0.9483508
- Recall: 0.94815789
- F1 score: 0.9482136

### CNN News

- Loss: 0.119858004
- Accuracy: 0.963338028
- Precision: 0.96379391
- Recall: 0.963338028
- F1 score: 0.963524631

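Weighted metrics like those above can be computed from raw predictions with scikit-learn; the labels below are made-up stand-ins for illustration, not the actual test-set predictions:

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Hypothetical gold labels and model predictions for illustration only
y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 1, 0, 2, 2]

accuracy = accuracy_score(y_true, y_pred)
# average="weighted" matches the multi-class metrics reported above
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted"
)
print(accuracy, precision, recall, f1)  # each 0.75 for this toy example
```
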
## Usage

You can use the pre-trained model to make predictions on new news articles. Here's an example:

```python
import torch
from transformers import AutoTokenizer, DistilBertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("AyoubChLin/distilbert_ag_cnn")
model = DistilBertForSequenceClassification.from_pretrained("AyoubChLin/distilbert_ag_cnn")

# Replace input_text with the actual news article
input_text = "The stock market had a strong performance today, driven by strong earnings reports from technology companies."

# Tokenize the input text
inputs = tokenizer(input_text, return_tensors="pt")

# Make a prediction
with torch.no_grad():
    outputs = model(**inputs)
predicted_label = torch.argmax(outputs.logits, dim=1)

# Print the predicted label
print(predicted_label)
```
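To turn the raw logits from the snippet above into class probabilities, apply a softmax; the logits here are hypothetical stand-ins for `outputs.logits`:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits standing in for outputs.logits (batch of 1, 3 classes)
logits = torch.tensor([[2.0, 0.5, -1.0]])

probs = F.softmax(logits, dim=1)              # probabilities summing to 1
predicted_label = torch.argmax(probs, dim=1)  # same argmax as on raw logits
confidence = probs[0, predicted_label.item()].item()

print(predicted_label.item(), confidence)
```

Mapping the integer label back to a human-readable class name is typically done through the model's `config.id2label` dictionary.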

## Credits

- The MIX dataset was collected and annotated by [CHERGUELAINE Ayoub](https://www.linkedin.com/in/ayoub-cherguelaine/)