RoTRAG RoBERTa Routing Classifier
This model is the routing classifier used in RoTRAG: Retrieval-Augmented Rule-of-Thumb Reasoning for Conversation Harm Detection.
The classifier decides whether a previously generated Rule of Thumb (RoT) can be carried over to the current dialogue turn or whether RoTRAG should generate a new RoT for the current context.
Label Space
| Label | Meaning |
|---|---|
| 0 | The previous RoT is no longer sufficient; generate a new RoT. |
| 1 | The previous RoT can be retained for the current turn. |
Intended Use
This checkpoint is intended for the router step in the RoTRAG pipeline. Given the previous RoT information and the current context-response pair, it predicts whether the normative core of the previous RoT still applies.
It is not a standalone harm detector. The classifier only decides RoT carry-over versus RoT regeneration; downstream safety or prosocial labels are produced by separate RoTRAG reasoning and prediction steps.
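The router's place in the pipeline can be sketched as a small function. This is a minimal illustration, not the RoTRAG implementation: `route_rot`, `classify`, and `generate_rot` are hypothetical names. `classify` would wrap this checkpoint (as in Example Usage), and `generate_rot` stands in for RoTRAG's separate RoT generation step.

```python
from typing import Callable

# Label ids from the Label Space table.
REGENERATE = 0  # previous RoT no longer sufficient
CARRY_OVER = 1  # previous RoT still applies


def route_rot(
    previous_rot: str,
    context: str,
    response: str,
    classify: Callable[[str, str, str], int],
    generate_rot: Callable[[str, str], str],
) -> str:
    """Return the RoT to use for the current turn (hypothetical router step)."""
    label = classify(previous_rot, context, response)
    if label == CARRY_OVER:
        # Reuse the previous RoT unchanged.
        return previous_rot
    if label == REGENERATE:
        # Hand off to a separate (hypothetical) RoT generator.
        return generate_rot(context, response)
    raise ValueError(f"unexpected routing label: {label}")
```

Keeping the classifier and the generator as injected callables mirrors the card's point that the router only decides carry-over versus regeneration; everything downstream stays in other steps.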
Model Details
- Architecture: `RobertaForSequenceClassification`
- Base model: `roberta-large` (`FacebookAI/roberta-large` on the Hugging Face Hub)
- Task: binary sequence classification
- Input type: previous RoT(s), current dialogue context, and current response
- Output: binary routing label (`0` or `1`)
Training Data
The routing supervision data was constructed from ProsocialDialog train/validation instances. Annotation began with 1,000 routing examples labeled by 10 expert annotators, and the final gold label for each example was determined by hard (majority) voting.
The resulting annotation rule was distilled into a prompt and validated against the human-labeled subset, reaching 0.98 accuracy relative to the human gold labels. The validated prompt was then used to expand labels over the remaining training data.
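Hard voting assigns each example the label chosen by the most annotators, and the 0.98 figure is the fraction of examples where the prompt's label equals that gold label. A minimal sketch of both steps follows. With 10 annotators a 5-5 split is possible; the card does not say how such ties were resolved, so returning `None` on a tie is an assumption, and `hard_vote` and `label_accuracy` are hypothetical helper names.

```python
from collections import Counter
from typing import Optional, Sequence


def hard_vote(labels: Sequence[int]) -> Optional[int]:
    """Return the most common annotator label, or None on a tie (assumed tie rule)."""
    if not labels:
        raise ValueError("need at least one annotator label")
    ranked = Counter(labels).most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return None  # e.g. a 5-5 split among 10 annotators
    return ranked[0][0]


def label_accuracy(predicted: Sequence[int], gold: Sequence[int]) -> float:
    """Fraction of examples where the prompt's label matches the gold label."""
    if len(predicted) != len(gold) or not gold:
        raise ValueError("need equal-length, non-empty label lists")
    return sum(p == g for p, g in zip(predicted, gold)) / len(gold)
```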
Expanded label counts reported for classifier training:
| Split | Label 0 | Label 1 |
|---|---|---|
| Train | 14,528 | 46,206 |
| Validation | 2,489 | 8,147 |
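The label counts are skewed toward label 1. The short computation below, using only the counts in the table, gives each split's size and label-1 share. The label-1 share is also the accuracy of a trivial baseline that always predicts "carry over", which is a useful floor when reading the evaluation results. `split_summary` is a hypothetical helper name.

```python
# Label counts per split, copied from the table: (label 0, label 1).
LABEL_COUNTS = {
    "train": (14_528, 46_206),
    "validation": (2_489, 8_147),
}


def split_summary(n_label0: int, n_label1: int) -> tuple[int, float]:
    """Return the split size and the share of label 1 (carry over)."""
    total = n_label0 + n_label1
    return total, n_label1 / total


for split, (n0, n1) in LABEL_COUNTS.items():
    total, share1 = split_summary(n0, n1)
    # share1 is also the accuracy of always predicting label 1.
    print(f"{split}: {total} examples, label-1 share {share1:.1%}")
```

This gives 60,734 training and 10,636 validation examples, with label 1 making up about 76.1% and 76.6% of them respectively.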
Training Setup
The classifier was fine-tuned with the following setup:
| Hyperparameter | Value |
|---|---|
| Base model | roberta-large |
| Epochs | 3 |
| Learning rate | 2e-5 |
| Scheduler | Linear, no warm-up |
| Optimizer | AdamW |
| Weight decay | 0.01 |
| Train batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Evaluation batch size | 16 |
| Max sequence length | 256 |
| Padding | Dynamic |
| Mixed precision | FP16 |
| Best checkpoint criterion | Macro F1 |
| Random seed | 42 |
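The schedule implied by this table can be worked out from the training-set size. A minimal sketch, assuming a single device (so the train batch size of 8 is the whole per-step batch), no dropped batches, and that each epoch's batches are grouped into optimizer updates of two; the function names are hypothetical. The learning-rate formula matches the linear decay of `transformers.get_linear_schedule_with_warmup` with `num_warmup_steps=0`, though the card does not say which implementation was used.

```python
import math

# Values from the Training Data and Training Setup tables.
NUM_TRAIN = 14_528 + 46_206  # 60,734 training examples
BATCH_SIZE = 8
GRAD_ACCUM = 2
EPOCHS = 3
PEAK_LR = 2e-5


def total_update_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Optimizer updates, assuming one device and floor-grouping of batches."""
    batches_per_epoch = math.ceil(n_examples / batch_size)
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return updates_per_epoch * epochs


def linear_lr(step: int, total_steps: int, peak_lr: float) -> float:
    """Linear decay from peak_lr at step 0 to 0 at total_steps, no warm-up."""
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps))


total_steps = total_update_steps(NUM_TRAIN, BATCH_SIZE, GRAD_ACCUM, EPOCHS)
print(total_steps)
print(linear_lr(total_steps // 2, total_steps, PEAK_LR))
```

Under these assumptions, 60,734 examples give 7,592 batches and 3,796 optimizer updates per epoch, or 11,388 updates over three epochs; the learning rate falls linearly from 2e-5 to 0 and is 1e-5 at the midpoint.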
Evaluation
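On the validation split, with label 1 (retain the previous RoT) as the positive class, the classifier produced the following confusion-matrix counts. The four counts sum to 10,626, ten fewer than the 10,636 validation instances in the label-count table, and the shares are computed over that 10,626 total: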
| Outcome | Count | Share |
|---|---|---|
| True positive | 7,461 | 70.215% |
| True negative | 1,743 | 16.403% |
| False positive | 746 | 7.021% |
| False negative | 676 | 6.362% |
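Because the best checkpoint was selected by macro F1, it helps to derive per-class and macro scores from these counts. The sketch below treats label 1 (carry over) as the positive class, which is consistent with TN + FP = 2,489 matching the label-0 validation count. The resulting numbers are computed from the reported counts; they are not figures reported by the authors.

```python
# Confusion-matrix counts from the table; label 1 (carry over) is the positive class.
TP, TN, FP, FN = 7_461, 1_743, 746, 676


def precision_recall_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Standard precision, recall, and F1 for one class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


total = TP + TN + FP + FN
accuracy = (TP + TN) / total

# Class 1: its hits are TP, its false positives FP, its misses FN.
p1, r1, f1_1 = precision_recall_f1(TP, FP, FN)
# Class 0: roles swap, so TN are its hits, FN (true 1 predicted 0) are its
# false positives, and FP (true 0 predicted 1) are its misses.
p0, r0, f1_0 = precision_recall_f1(TN, FN, FP)

macro_f1 = (f1_0 + f1_1) / 2

print(f"accuracy={accuracy:.4f}")
print(f"class 1: precision={p1:.4f} recall={r1:.4f} f1={f1_1:.4f}")
print(f"class 0: precision={p0:.4f} recall={r0:.4f} f1={f1_0:.4f}")
print(f"macro F1={macro_f1:.4f}")
```

With these counts, accuracy is about 0.866, class-1 F1 about 0.913, class-0 F1 about 0.710, and macro F1 about 0.812. The gap between the per-class scores reflects weaker performance on the minority "generate a new RoT" class.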
The reported behavior is strongest when the current turn clearly preserves or clearly departs from the previous RoT. Remaining errors are often borderline cases where the current response is partially related to the previous RoT but shifts specificity, framing, or normative emphasis.
Example Usage
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "hgLeo12/RoTRAG_roberta_routing_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()  # disable dropout for inference

previous_rot = "You should not manipulate others."
current_context = "She gets what she gets. Her bad habits are getting out of control."
current_response = (
    "You would have more luck if you just talked to her and explained the problem. "
    "If she does not want to change, then maybe you should move on."
)

# Combine the previous RoT, context, and response into a single input string.
text = (
    f"Previous RoT: {previous_rot}\n"
    f"Current Context: {current_context}\n"
    f"Current Response: {current_response}"
)

# Truncate to the 256-token maximum sequence length used in training.
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = int(logits.argmax(dim=-1).item())
print(predicted_label)  # 0 = generate a new RoT, 1 = carry over the previous RoT
```
Interpretation:
- `0`: generate a new RoT.
- `1`: carry over the previous RoT.
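The example prints only the argmax label. If a confidence score is also wanted, the two logits can be passed through a softmax. A minimal pure-Python sketch, assuming logit index `i` corresponds to label `i` (the usual layout for `AutoModelForSequenceClassification` heads); `routing_decision` and the `ACTIONS` strings are illustrative and are not read from the model's config.

```python
import math
from typing import Sequence

# Illustrative action strings for each label id (not read from the model config).
ACTIONS = {0: "generate a new RoT", 1: "carry over the previous RoT"}


def routing_decision(logits: Sequence[float]) -> tuple[int, float, str]:
    """Softmax the two logits and return (label, probability, action)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs[label], ACTIONS[label]
```

With the Example Usage snippet, this would be called as `routing_decision(logits[0].tolist())`.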
Limitations
- The model is trained for RoT routing, not final safety classification.
- The decision boundary can be ambiguous when the current response partially preserves the previous RoT while changing its framing or specificity.
- Performance should be interpreted in the context of the RoTRAG pipeline and the ProsocialDialog-derived supervision setup.
Related Repository
GitHub repository: GitLeo1/RoTRAG-Rule-of-Thumb-Reasoning-for-Conversation-Harm-Detection-with-Retrieval-Augmented-Generation