RoTRAG RoBERTa Routing Classifier

This model is the routing classifier used in RoTRAG: Retrieval-Augmented Rule-of-Thumb Reasoning for Conversation Harm Detection.

The classifier decides whether a previously generated Rule of Thumb (RoT) can be carried over to the current dialogue turn or whether RoTRAG should generate a new RoT for the current context.

Label Space

| Label | Meaning |
|---|---|
| 0 | The previous RoT is no longer sufficient; generate a new RoT. |
| 1 | The previous RoT can be retained for the current turn. |

Intended Use

This checkpoint is intended for the router step in the RoTRAG pipeline. Given the previous RoT information and the current context-response pair, it predicts whether the normative core of the previous RoT still applies.

It is not a standalone harm detector. The classifier only decides RoT carry-over versus RoT regeneration; downstream safety or prosocial labels are produced by separate RoTRAG reasoning and prediction steps.
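The router step is a two-way branch in the pipeline. The Python sketch below shows one way a caller might wire it up. It is an assumption about the surrounding code, not RoTRAG's actual implementation: `route_rot`, `classify`, and `generate_rot` are hypothetical names. `classify` stands in for a wrapper around this checkpoint that returns 0 or 1, and `generate_rot` stands in for RoTRAG's separate RoT generation step.

```python
from typing import Callable

# Label ids from the Label Space table.
GENERATE_NEW_ROT = 0
RETAIN_PREVIOUS_ROT = 1


def route_rot(
    previous_rot: str,
    context: str,
    response: str,
    classify: Callable[[str, str, str], int],
    generate_rot: Callable[[str, str], str],
) -> str:
    """Return the RoT to use for the current turn.

    Hypothetical helper: `classify` wraps the routing classifier (returns 0 or 1),
    `generate_rot` stands in for the pipeline's RoT generation step.
    """
    label = classify(previous_rot, context, response)
    if label == RETAIN_PREVIOUS_ROT:
        return previous_rot  # carry the previous RoT over unchanged
    if label == GENERATE_NEW_ROT:
        return generate_rot(context, response)  # regenerate for the current context
    raise ValueError(f"unexpected routing label: {label}")
```

Because the classifier only picks a branch, everything downstream (safety or prosocial prediction) receives an RoT either way. Only the source of that RoT changes.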

Model Details

  • Architecture: RobertaForSequenceClassification
  • Base model: roberta-large
  • Task: binary sequence classification
  • Input type: previous RoT(s), current dialogue context, and current response
  • Output: binary routing label (0 or 1)

Training Data

The routing supervision data was constructed from ProsocialDialog train/validation instances. Annotation began with 1,000 routing examples labeled by 10 expert annotators. The gold label for each example was determined by hard (majority) voting.

The resulting annotation rule was distilled into a prompt and validated against the human-labeled subset, reaching 0.98 accuracy relative to the human gold labels. The validated prompt was then used to expand labels over the remaining training data.
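The two steps described here are hard voting over annotator labels and measuring how often the prompt's labels match the resulting gold labels. Both can be sketched in a few lines of Python. The function names are hypothetical. The tie rule is also an assumption, because the card does not say how ties were resolved (a 5-5 split is possible with 10 annotators). This sketch returns `None` on a tie so such examples can be flagged for adjudication.

```python
from collections import Counter
from typing import Optional, Sequence


def hard_vote(labels: Sequence[int]) -> Optional[int]:
    """Majority label among annotator votes, or None on a tie (assumed tie rule)."""
    if not labels:
        raise ValueError("need at least one vote")
    counts = Counter(labels).most_common()
    if len(counts) > 1 and counts[0][1] == counts[1][1]:
        return None  # tie: no strict majority
    return counts[0][0]


def agreement_accuracy(predicted: Sequence[int], gold: Sequence[int]) -> float:
    """Fraction of examples where prompt-produced labels match the gold labels."""
    if len(predicted) != len(gold) or not gold:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == g for p, g in zip(predicted, gold)) / len(gold)
```

For example, eight votes for label 1 and two for label 0 yield a gold label of 1. An accuracy of 0.98 would mean the prompt's label matched the gold label on 98% of the human-labeled subset.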

Expanded label counts reported for classifier training:

| Split | Label 0 | Label 1 |
|---|---:|---:|
| Train | 14,528 | 46,206 |
| Validation | 2,489 | 8,147 |
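The expanded labels are imbalanced: roughly three quarters of the examples in each split carry label 1. The short Python computation below derives the per-split label shares from the table. The label-1 share also equals the accuracy of always predicting the majority label, which is a useful reference point when reading the Evaluation numbers. `label_shares` is a hypothetical helper.

```python
# Label counts from the table (label 0 = generate a new RoT, label 1 = retain).
COUNTS = {
    "train": {0: 14_528, 1: 46_206},
    "validation": {0: 2_489, 1: 8_147},
}


def label_shares(counts: dict) -> dict:
    """Fraction of examples carrying each label."""
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}


for split, counts in COUNTS.items():
    shares = label_shares(counts)
    print(f"{split}: n={sum(counts.values()):,}  "
          f"label 0 = {shares[0]:.1%}  label 1 = {shares[1]:.1%}")
# prints:
# train: n=60,734  label 0 = 23.9%  label 1 = 76.1%
# validation: n=10,636  label 0 = 23.4%  label 1 = 76.6%
```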

Training Setup

The classifier was fine-tuned with the following setup:

| Hyperparameter | Value |
|---|---|
| Base model | roberta-large |
| Epochs | 3 |
| Learning rate | 2e-5 |
| Scheduler | Linear, no warm-up |
| Optimizer | AdamW |
| Weight decay | 0.01 |
| Train batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Evaluation batch size | 16 |
| Max sequence length | 256 |
| Padding | Dynamic |
| Mixed precision | FP16 |
| Best checkpoint criterion | Macro F1 |
| Random seed | 42 |
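The table maps naturally onto Hugging Face `transformers.TrainingArguments` keyword arguments. The Python sketch below is our reading of the table, not the authors' training script. Several choices are assumptions: per-epoch evaluation and saving, and `optim="adamw_torch"` for AdamW. Note that `eval_strategy` is the parameter name in recent `transformers` releases, while older ones call it `evaluation_strategy`. `compute_metrics` follows the `(logits, labels)` form that `Trainer` passes, and `build_training_args` is a hypothetical helper. Dynamic padding is typically handled by `transformers.DataCollatorWithPadding`.

```python
import numpy as np

# Hyperparameters from the table as TrainingArguments kwargs (our mapping, not the authors' script).
TRAINING_KWARGS = dict(
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    warmup_steps=0,                   # no warm-up
    optim="adamw_torch",              # AdamW (assumed implementation choice)
    weight_decay=0.01,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,    # 8 x 2 = effective batch size 16 on one device
    per_device_eval_batch_size=16,
    fp16=True,                        # FP16 mixed precision; needs a CUDA GPU
    seed=42,
    eval_strategy="epoch",            # assumption: evaluate once per epoch
    save_strategy="epoch",            # must match eval_strategy for best-model loading
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1", # key returned by compute_metrics below
)
MAX_LENGTH = 256  # truncation length passed to the tokenizer


def compute_metrics(eval_pred):
    """Macro F1 over labels 0 and 1 from (logits, labels) numpy arrays."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    f1s = []
    for label in (0, 1):
        tp = np.sum((preds == label) & (labels == label))
        fp = np.sum((preds == label) & (labels != label))
        fn = np.sum((preds != label) & (labels == label))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return {"macro_f1": float(np.mean(f1s))}


def build_training_args(output_dir: str):
    """Construct TrainingArguments lazily so this module imports without transformers."""
    from transformers import TrainingArguments
    return TrainingArguments(output_dir=output_dir, **TRAINING_KWARGS)
```

Selecting by macro F1 rather than accuracy matters here because label 1 dominates the data. Macro F1 weights the minority "generate a new RoT" class equally.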

Evaluation

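On the validation split (10,636 routing instances), the classifier produced the following confusion matrix, with label 1 (retain the previous RoT) as the positive class. The four counts sum to 10,626, ten fewer than the split size, and each share is computed over 10,626: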

| Outcome | Count | Share |
|---|---:|---:|
| True positive | 7,461 | 70.215% |
| True negative | 1,743 | 16.403% |
| False positive | 746 | 7.021% |
| False negative | 676 | 6.362% |
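Because the best checkpoint was selected by macro F1, it can help to derive that metric and accuracy directly from the confusion matrix. The Python computation below does this. The resulting figures are arithmetic on the table, not separately reported results. When label 0 is treated as the positive class, the roles of false positives and false negatives swap.

```python
# Confusion-matrix counts from the table; label 1 (retain) is the positive class.
TP, TN, FP, FN = 7_461, 1_743, 746, 676


def f1(tp: int, fp: int, fn: int) -> float:
    """F1 score from true positives, false positives, and false negatives."""
    return 2 * tp / (2 * tp + fp + fn)


total = TP + TN + FP + FN
accuracy = (TP + TN) / total
f1_retain = f1(TP, FP, FN)    # label 1 as positive
f1_generate = f1(TN, FN, FP)  # label 0 as positive: FP and FN swap roles
macro_f1 = (f1_retain + f1_generate) / 2

print(f"total={total:,} accuracy={accuracy:.3f} macro_F1={macro_f1:.3f}")
# prints: total=10,626 accuracy=0.866 macro_F1=0.812
```

The gap between the two per-class F1 scores (about 0.91 for label 1 versus about 0.71 for label 0) reflects the class imbalance in the training data.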

The classifier is most reliable when the current turn clearly preserves or clearly departs from the previous RoT. Remaining errors tend to be borderline cases in which the current response is partly related to the previous RoT but shifts its specificity, framing, or normative emphasis.

Example Usage

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Hugging Face repository id of this checkpoint
model_id = "hgLeo12/RoTRAG_roberta_routing_classifier"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()

previous_rot = "You should not manipulate others."
current_context = "She gets what she gets. Her bad habits are getting out of control."
current_response = (
    "You would have more luck if you just talked to her and explained the problem. "
    "If she does not want to change, then maybe you should move on."
)

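# Format the routing input as labeled fields, one per line
text = (
    f"Previous RoT: {previous_rot}\n"
    f"Current Context: {current_context}\n"
    f"Current Response: {current_response}"
)

# Truncate to 256 tokens, the max sequence length used during fine-tuning
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, 2)

# 0 = generate a new RoT, 1 = carry over the previous RoT
predicted_label = int(logits.argmax(dim=-1).item())
print(predicted_label)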

Interpretation:

  • 0: generate a new RoT.
  • 1: carry over the previous RoT.
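The example prints only the argmax label. A caller may also want the route name and the softmax probability of the chosen label, for instance to inspect borderline turns. The Python sketch below decodes the two logits for that purpose. `decode_routing_logits` is a hypothetical helper, and the probability is an uncalibrated softmax score, not a confidence estimate validated by the authors. Ties resolve to label 0, matching `argmax`.

```python
import math
from typing import List, Tuple

ROUTES = {0: "generate a new RoT", 1: "carry over the previous RoT"}


def decode_routing_logits(logits: List[float]) -> Tuple[int, str, float]:
    """Map [label-0 logit, label-1 logit] to (label, route, softmax probability)."""
    if len(logits) != 2:
        raise ValueError("expected two logits: [label 0, label 1]")
    m = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    label = 0 if probs[0] >= probs[1] else 1  # ties go to label 0, like argmax
    return label, ROUTES[label], probs[label]
```

With the example above, `decode_routing_logits(logits[0].tolist())` returns the label, route, and probability for that single input.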

Limitations

  • The model is trained for RoT routing, not final safety classification.
  • The decision boundary can be ambiguous when the current response partially preserves the previous RoT while changing its framing or specificity.
  • Performance should be interpreted in the context of the RoTRAG pipeline and the ProsocialDialog-derived supervision setup.

Related Repository

GitHub repository: GitLeo1/RoTRAG-Rule-of-Thumb-Reasoning-for-Conversation-Harm-Detection-with-Retrieval-Augmented-Generation

Checkpoint Files

  • Hugging Face repository: hgLeo12/RoTRAG_roberta_routing_classifier (fine-tuned from roberta-large)
  • Weight format: Safetensors
  • Model size: 0.4B params
  • Tensor type: F32