How to use havvanur92/news-topic-classification-model with Keras:

    # Available backend options are: "jax", "torch", "tensorflow".
    import os
    os.environ["KERAS_BACKEND"] = "jax"

    import keras

    model = keras.saving.load_model("hf://havvanur92/news-topic-classification-model")
RNN-News Topic Classification
The goal of this project is to build an RNN-based model that automatically classifies news articles from the AG News dataset into four predefined categories.
Table of Contents
- Dependencies & Environment
- Dataset Overview
- Text Preprocessing
- RNN Model Architecture
- Hyperparameter Search
- Model Evaluation
Dependencies & Environment
Main Dependencies
| Library | Purpose |
|---|---|
| numpy, pandas | Data manipulation and preprocessing |
| matplotlib, seaborn | Visualizations (plots, ROC curves, confusion matrix) |
| tensorflow / keras | Building, training, and tuning the RNN model |
| keras-tuner | Hyperparameter optimization (Random Search) |
| scikit-learn | Evaluation metrics (classification report, ROC, confusion matrix) |
Environment: Kaggle Notebooks with a T4 GPU (Python 3.x, TensorFlow 2.x)
Installation
If running locally, install dependencies with:
pip install numpy pandas matplotlib seaborn scikit-learn tensorflow keras-tuner
Dataset Overview
The AG News dataset contains news articles labeled into four main categories: World, Sports, Business, and Science/Technology. Each sample consists of a title and a short description, providing concise textual information for classification tasks.
Source: AG News, collected by the academic news search engine ComeToMyHead
Categories: 4 predefined classes
Files: train.csv and test.csv
Content: Class index, title, and description
Total Samples: 127,600 (120,000 training + 7,600 test)
Text Preprocessing
The Title and Description fields were merged into a single text input and prepared for the RNN model using tokenization.
Process:
- Merged title + description.
- Applied a tokenizer with a 10,000-word vocabulary and an <OOV> token for unseen words.
- Converted text into integer sequences.
- Applied padding and truncation (padding='post', truncating='post') to ensure uniform sequence length (maxlen = 100).
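The steps above can be sketched without any dependencies. This is only an illustration of the logic, not the Keras tokenizer itself: the helper names (`merge_fields`, `fit_vocab`, `encode`) are hypothetical, the split is on whitespace only, and reserving index 0 for padding and index 1 for `<OOV>` is an assumption about how the ids were laid out.

```python
from collections import Counter

VOCAB_SIZE = 10_000  # vocabulary size stated above
MAXLEN = 100         # sequence length stated above
PAD_INDEX = 0        # assumption: 0 is reserved for padding
OOV_INDEX = 1        # assumption: 1 is reserved for the <OOV> token


def merge_fields(title, description):
    """Join title and description into a single text input."""
    return f"{title} {description}"


def fit_vocab(texts, vocab_size=VOCAB_SIZE):
    """Map the most frequent words to integer ids, starting at 2."""
    counts = Counter(word for text in texts for word in text.lower().split())
    # Keep vocab_size - 2 words so every id stays below vocab_size.
    most_common = counts.most_common(vocab_size - 2)
    return {word: idx for idx, (word, _) in enumerate(most_common, start=2)}


def encode(text, vocab, maxlen=MAXLEN):
    """Convert text to ids, then truncate and pad at the end ('post')."""
    ids = [vocab.get(word, OOV_INDEX) for word in text.lower().split()]
    ids = ids[:maxlen]                               # truncating='post'
    return ids + [PAD_INDEX] * (maxlen - len(ids))   # padding='post'


# Toy example with made-up text.
texts = [merge_fields("Stocks rally", "Markets rise as tech shares rally")]
vocab = fit_vocab(texts)
sequence = encode("Tech stocks rally again", vocab)
```

Here `sequence` has exactly 100 entries: three known word ids, one `<OOV>` id for "again", and zeros for the remaining positions.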
RNN Model Architecture
The model uses a SimpleRNN structure to learn patterns from the text.
- Embedding Layer: Turns each word into a vector.
- SimpleRNN Layer: Processes the sequence one token at a time, updating a hidden state that summarizes the text read so far.
- Dropout: Helps reduce overfitting.
- Dense Output Layer: Softmax layer that predicts the correct news category.
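To make the data flow through these layers concrete, here is a dependency-free sketch of a forward pass. It assumes the standard simple-RNN recurrence `h_t = tanh(x_t W_x + h_{t-1} W_h + b)`; the sizes and random weights are toy values chosen for illustration, not the trained model's, and dropout is left out because it is inactive at inference time.

```python
import math
import random


def simple_rnn_forward(token_ids, embedding, w_x, w_h, bias):
    """Run a simple RNN over a sequence and return the final hidden state."""
    units = len(bias)
    hidden = [0.0] * units
    for token in token_ids:
        x = embedding[token]  # embedding lookup: token id -> vector
        # The new state depends on the current input and the previous state.
        hidden = [
            math.tanh(
                sum(x[i] * w_x[i][j] for i in range(len(x)))
                + sum(hidden[k] * w_h[k][j] for k in range(units))
                + bias[j]
            )
            for j in range(units)
        ]
    return hidden


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows, cols, rng):
    """Small random weights, for illustration only."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
vocab_size, embed_dim, units, num_classes = 20, 4, 3, 4  # toy sizes
embedding = random_matrix(vocab_size, embed_dim, rng)
w_x = random_matrix(embed_dim, units, rng)
w_h = random_matrix(units, units, rng)
bias = [0.0] * units
w_out = random_matrix(units, num_classes, rng)

final_state = simple_rnn_forward([5, 3, 7, 0, 0], embedding, w_x, w_h, bias)
logits = [
    sum(final_state[k] * w_out[k][c] for k in range(units))
    for c in range(num_classes)
]
probs = softmax(logits)  # one probability per news category
```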
The model was trained using the Adam optimizer, gradient clipping for stability, and categorical crossentropy for multi-class classification.
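As a rough illustration of two of these training ingredients, the sketch below shows norm-based gradient clipping and the categorical crossentropy loss for a single sample. Clipping by L2 norm is an assumption based on the `Clipnorm` hyperparameter listed later in this card; the function names and the numbers are hypothetical.

```python
import math


def clip_by_norm(gradient, clipnorm):
    """Rescale a gradient vector so its L2 norm does not exceed clipnorm."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm <= clipnorm:
        return list(gradient)
    scale = clipnorm / norm
    return [g * scale for g in gradient]


def categorical_crossentropy(one_hot_target, probabilities, eps=1e-7):
    """Loss for one sample: -sum(y * log(p)), with p floored at eps."""
    return -sum(
        y * math.log(max(p, eps))
        for y, p in zip(one_hot_target, probabilities)
    )


# A gradient with norm 5 is scaled down to norm 1; its direction is unchanged.
clipped = clip_by_norm([3.0, 4.0], clipnorm=1.0)

# The true class is index 1, predicted with probability 0.7.
loss = categorical_crossentropy([0, 1, 0, 0], [0.1, 0.7, 0.1, 0.1])
```

Clipping bounds the size of each update, which helps with the exploding gradients that simple recurrent layers can produce on long sequences.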
Hyperparameter Search
To improve model performance, Random Search (via Keras Tuner) was used to find the best hyperparameters. The search optimized:
- Embedding dimension
- Number of RNN units
- Dropout rate
- Gradient clipping value
A total of 30 different model configurations were tested. Each configuration was trained once, and early stopping was applied to prevent unnecessary training. This process helped identify the most effective architecture before training the final model.
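The idea behind random search can be sketched in plain Python. This is not the Keras Tuner implementation: the search ranges are hypothetical (this card names the four hyperparameters but not their ranges), and `dummy_objective` is a stand-in for "train once with early stopping and return validation accuracy".

```python
import random

# Hypothetical search space; the actual ranges are not stated in this card.
SEARCH_SPACE = {
    "embedding_dim": [32, 64, 128, 256],
    "rnn_units": list(range(32, 257, 32)),
    "dropout": [0.1, 0.2, 0.3, 0.4, 0.5],
    "clipnorm": [0.5, 1.0, 2.0, 5.0],
}


def sample_config(rng):
    """Draw one value per hyperparameter, uniformly at random."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}


def random_search(train_and_validate, max_trials=30, seed=0):
    """Evaluate max_trials random configurations and keep the best one."""
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(max_trials):
        config = sample_config(rng)
        score = train_and_validate(config)  # one training run per trial
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


def dummy_objective(config):
    """Placeholder score so the sketch runs; NOT a real accuracy."""
    return -abs(config["dropout"] - 0.3) - abs(config["rnn_units"] - 160) / 1000


best_config, best_score = random_search(dummy_objective, max_trials=30)
```

In a real run, `train_and_validate` would build and fit the model for the sampled configuration. Unlike this sketch, a tuner may also avoid repeating a configuration it has already tried.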
Best Hyperparameters
After the hyperparameter search, the tuner identified the best configuration, achieving a validation accuracy of 0.888.
Selected Hyperparameters:
- Embedding Dimension: 128
- RNN Units: 160
- Dropout Rate: 0.3
- Clipnorm: 1.0
Model Evaluation
After selecting the best hyperparameters, the final model was evaluated on the test set.
Test Performance:
Accuracy: 0.8915
Loss: 0.3542
AUC: 0.9756
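Test accuracy (0.8915) is in line with the best validation accuracy from the search (0.888), which suggests that the model generalizes to unseen articles across the four news categories.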
ROC Curves
ROC curves were plotted for each class using the predicted probabilities. They show how well the model distinguishes each category across different thresholds.
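A minimal sketch of the one-vs-rest idea behind per-class ROC analysis follows. It computes the area under each class's ROC curve as the probability that a positive sample is scored above a negative one. The labels and probabilities are made-up toy values, and the code is an illustration rather than the evaluation script used for this model.

```python
def roc_auc(binary_labels, scores):
    """AUC as the chance that a positive outranks a negative (ties count 0.5)."""
    positives = [s for y, s in zip(binary_labels, scores) if y == 1]
    negatives = [s for y, s in zip(binary_labels, scores) if y == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


def one_vs_rest_auc(true_classes, probabilities, num_classes):
    """One AUC per class, treating that class as positive and the rest as negative."""
    return [
        roc_auc(
            [1 if y == c else 0 for y in true_classes],
            [row[c] for row in probabilities],
        )
        for c in range(num_classes)
    ]


# Toy data: 5 samples, 4 classes, each row is a predicted probability vector.
true_classes = [0, 1, 2, 3, 2]
probabilities = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.6, 0.2, 0.1],
    [0.1, 0.1, 0.5, 0.3],
    [0.1, 0.1, 0.3, 0.5],
    [0.2, 0.1, 0.3, 0.4],
]
aucs = one_vs_rest_auc(true_classes, probabilities, num_classes=4)
```

The pairwise loop is quadratic in the number of samples, so it suits small illustrations; for a full test set, a library routine such as scikit-learn's `roc_curve` is the practical choice.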
Figure: Multi-class ROC curve
Classification Report
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| World | 0.91 | 0.89 | 0.90 | 1900 |
| Sports | 0.95 | 0.95 | 0.95 | 1900 |
| Business | 0.84 | 0.89 | 0.87 | 1900 |
| Science | 0.89 | 0.86 | 0.87 | 1900 |
| Overall Accuracy | - | - | 0.90 | 7600 |
The model performs best on Sports (F1 0.95) and World (F1 0.90), with slightly lower scores on Business and Science (F1 0.87 each), likely because these two categories share more terminology.
Confusion Matrix
A confusion matrix was generated to visualize how well the model distinguishes between the four categories.
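The relationship between a confusion matrix and the per-class precision, recall, and F1 scores can be sketched as follows. The labels are toy values with a hypothetical 0-based class order; this is not the model's actual output.

```python
CLASS_NAMES = ["World", "Sports", "Business", "Science"]  # assumed 0-based order


def confusion_matrix(true_classes, predicted_classes, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_classes, predicted_classes):
        matrix[t][p] += 1
    return matrix


def per_class_metrics(matrix):
    """Return (precision, recall, f1) for each class of a confusion matrix."""
    metrics = []
    for c in range(len(matrix)):
        true_positives = matrix[c][c]
        predicted = sum(row[c] for row in matrix)  # column sum
        actual = sum(matrix[c])                    # row sum
        precision = true_positives / predicted if predicted else 0.0
        recall = true_positives / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics.append((precision, recall, f1))
    return metrics


# Toy labels: Business and Science are confused with each other once each.
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 0, 1, 1, 2, 3, 3, 2]
matrix = confusion_matrix(y_true, y_pred, num_classes=len(CLASS_NAMES))
metrics = per_class_metrics(matrix)
```

Off-diagonal cells show which pairs of categories are mistaken for one another, which is the detail that the per-class scores alone do not reveal.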