|
|
|
# t5-small Quantized Model for Text Summarization on the Reddit-TIFU Dataset
|
|
|
This repository hosts a quantized version of the t5-small model, fine-tuned for text summarization using the Reddit-TIFU dataset. The model has been optimized using FP16 quantization for efficient deployment without significant accuracy loss. |
|
|
|
## Model Details |
|
|
|
- **Model Architecture:** t5-small (the small variant of T5, ~60M parameters)
|
- **Task:** Text summarization
|
- **Dataset:** Reddit-TIFU (Hugging Face Datasets) |
|
- **Quantization:** Float16 |
|
- **Fine-tuning Framework:** Hugging Face Transformers |
|
|
|
--- |
|
|
|
## Installation |
|
|
|
```bash |
|
pip install datasets transformers rouge-score evaluate |
|
``` |
|
|
|
--- |
|
|
|
## Loading the Model |
|
|
|
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load tokenizer and model.
# "t5-small" is the base checkpoint; point model_name at this repository's
# quantized-model/ directory to load the fine-tuned FP16 weights instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Define a test post
new_text = """
Today I was late to my morning meeting because I spilled coffee all over my laptop.
Then I realized my backup laptop was also out of battery.
Eventually joined from my phone, only to find out the meeting was cancelled.
"""

# Generate a summary with beam search
def generate_summary(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True
    ).to(device)

    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=100,
        min_length=5,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
```
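
The helper can then be called on the sample post above; the exact output depends on which checkpoint is actually loaded:

```python
print(generate_summary(new_text))
```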
|
|
|
--- |
|
|
|
## Performance Metrics |
|
|
|
- **ROUGE-1:** 19.590

- **ROUGE-2:** 4.270

- **ROUGE-L:** 16.390

- **ROUGE-Lsum:** 16.800
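
These scores were computed with the ROUGE metric. A minimal sketch of how comparable scores can be reproduced with the `evaluate` library (the prediction and reference lists below are placeholders, not the actual test outputs):

```python
import evaluate

# Placeholder predictions/references; in practice these come from the held-out split.
predictions = ["model generated summary"]
references = ["human written reference summary"]

rouge = evaluate.load("rouge")
scores = rouge.compute(predictions=predictions, references=references)
print(scores)  # keys: rouge1, rouge2, rougeL, rougeLsum
```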
|
|
|
--- |
|
|
|
## Fine-Tuning Details |
|
|
|
### Dataset |
|
|
|
The dataset is sourced from Hugging Face's `Reddit-TIFU` dataset, which contains roughly 79,000 Reddit posts paired with their summaries.

The original training and testing sets were merged, shuffled, and re-split using a 90/10 ratio.
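
A rough sketch of the loading and re-splitting step, assuming the `reddit_tifu` dataset on the Hugging Face Hub (the configuration name, seed, and loading details are assumptions and may differ depending on the installed `datasets` version):

```python
from datasets import load_dataset

# Load Reddit-TIFU; the "short" configuration contains roughly 79k posts.
# Depending on the datasets version, trust_remote_code=True may be required.
raw = load_dataset("reddit_tifu", "short")

# Shuffle and re-split 90/10 (seed chosen arbitrarily for reproducibility).
splits = raw["train"].shuffle(seed=42).train_test_split(test_size=0.1)
train_ds, test_ds = splits["train"], splits["test"]
```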
|
|
|
### Training |
|
|
|
- **Epochs:** 3 |
|
- **Batch size:** 8 |
|
- **Learning rate:** 2e-5 |
|
- **Evaluation strategy:** `epoch` |
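
The hyperparameters above map onto `Seq2SeqTrainingArguments` roughly as follows; this is a sketch, and `output_dir` plus any option not listed above are assumptions:

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",            # assumption: output path not specified above
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=2e-5,
    evaluation_strategy="epoch",       # renamed to eval_strategy in recent transformers releases
    predict_with_generate=True,        # assumption: lets ROUGE be computed during evaluation
)
```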
|
|
|
--- |
|
|
|
## Quantization |
|
|
|
Post-training quantization was applied using PyTorch's `half()` method, casting the model weights to FP16 to reduce model size and inference time.
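
A minimal sketch of that conversion step; `"finetuned-model"` is a placeholder for the fine-tuned FP32 checkpoint, while `quantized-model/` is the directory listed in the repository structure below:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the fine-tuned FP32 model, cast its weights to FP16, and save it.
model = AutoModelForSeq2SeqLM.from_pretrained("finetuned-model")
model = model.half()                      # cast all weights to torch.float16
model.save_pretrained("quantized-model")

# Save the tokenizer alongside the quantized weights.
tokenizer = AutoTokenizer.from_pretrained("finetuned-model")
tokenizer.save_pretrained("quantized-model")
```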
|
|
|
--- |
|
|
|
## Repository Structure |
|
|
|
```
.
├── quantized-model/              # Contains the quantized model files
│   ├── config.json
│   ├── model.safetensors
│   ├── tokenizer_config.json
│   ├── spiece.model
│   ├── special_tokens_map.json
│   ├── generation_config.json
│   └── tokenizer.json
└── README.md                     # Model documentation
```
|
|
|
--- |
|
|
|
## Limitations |
|
|
|
- The model is trained specifically for summarizing Reddit posts and may not generalize well to other domains.
|
- FP16 quantization may result in slight numerical instability in edge cases. |
|
|
|
|
|
--- |
|
|
|
## Contributing |
|
|
|
Feel free to open issues or submit pull requests to improve the model or documentation. |
|
|