metadata
language:
  - en
license: cc-by-nc-2.0
library_name: transformers
tags:
  - citation-verification
  - retrieval-augmented-generation
  - rag
  - cross-lingual
  - deberta
  - cross-encoder
  - nli
  - attribution
pipeline_tag: text-classification
datasets:
  - fever
  - din0s/asqa
  - miracl/hagrid
metrics:
  - f1
  - precision
  - recall
  - accuracy
  - roc_auc
base_model: microsoft/deberta-v3-base
model-index:
  - name: dualtrack-alignment-module
    results:
      - task:
          type: text-classification
          name: Citation Verification
        metrics:
          - type: f1
            value: 0.89
            name: F1 Score
          - type: accuracy
            value: 0.87
            name: Accuracy
          - type: roc_auc
            value: 0.94
            name: ROC-AUC

DualTrack Alignment Module

Anonymous submission to ACL 2026

A cross-encoder model for detecting citation drift in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (the source supports the claim).

Model Description

This model addresses a critical reliability problem in RAG systems: citation drift, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings, where the answer language differs from the source-document language.

Architecture

Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"
         ↓
    DeBERTa-v3-base (184M parameters)
         ↓
    [CLS] embedding (768-dim)
         ↓
    Linear(768, 2) → Softmax
         ↓
    Output: P(valid citation)
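Functionally this is the standard sequence-classification setup; the released checkpoint already bundles the head, but a minimal sketch of the forward pass described above (num_labels=2, softmax over the pooled [CLS] state), using the base model purely for illustration:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Illustration only: the released checkpoint ships with this head already trained.
tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
clf = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2  # Linear(768, 2) on top of the pooled [CLS] state
)

text = "User claim: ... [SEP] Evidence: ... [SEP] Source passage: ..."
inputs = tok(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    logits = clf(**inputs).logits                  # shape (1, 2)
    p_valid = torch.softmax(logits, dim=-1)[0, 1]  # P(valid citation)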

Why Cross-Encoder?

Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components together, enabling:

  • Cross-attention between claim and source
  • Detection of subtle semantic mismatches
  • Better handling of paraphrases vs. factual errors
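As a purely illustrative contrast: a bi-encoder scores cosine similarity between independently encoded texts, so a topically similar but factually wrong claim can still score high, while the cross-encoder reads claim and source in one forward pass with full cross-attention. The model names and pooling below are illustrative assumptions, not part of this release.

import torch
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
claim = "Python was created in 1989."
source = "Python is a programming language created by Guido van Rossum in 1991."

# Bi-encoder style: each text is encoded on its own, then compared.
encoder = AutoModel.from_pretrained("microsoft/deberta-v3-base")
def embed(text: str) -> torch.Tensor:
    hidden = encoder(**tok(text, return_tensors="pt", truncation=True)).last_hidden_state
    return hidden.mean(dim=1)  # mean pooling; the other text is never seen

with torch.no_grad():
    similarity = torch.cosine_similarity(embed(claim), embed(source))  # high: same topic

# Cross-encoder style (this model): both texts in one input, token-level cross-attention.
cross = AutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v3-base", num_labels=2)
with torch.no_grad():
    logits = cross(**tok(claim, source, return_tensors="pt", truncation=True)).logits
    p_valid = torch.softmax(logits, dim=-1)[0, 1]  # a trained head can flag the 1989 vs. 1991 mismatch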

Intended Use

Primary Use Cases

  1. Post-hoc citation verification: Validate citations in RAG outputs before serving to users
  2. Citation drift detection: Identify claims that have semantically drifted from their sources
  3. Training signal: Provide rewards for citation-aware generation
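For use case 3, the predicted probability can be used directly as a scalar reward; a minimal sketch (the reward scaling is an illustrative choice, and check_citation is defined under Basic Usage below):

def citation_reward(user_claim: str, evidence: str, source: str) -> float:
    """Scalar reward for citation-aware generation: P(valid citation)."""
    _, prob = check_citation(user_claim, evidence, source)  # see Basic Usage below
    return prob  # optionally rescale, e.g. 2 * prob - 1, to center the reward at 0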

Out of Scope

  • General NLI/entailment (model is specialized for RAG citation patterns)
  • Fact-checking against world knowledge (requires source passage)
  • Non-English source documents (trained on English sources only)

How to Use

Installation

pip install transformers torch

Basic Usage

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.
    
    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)
    
    Returns:
        (is_valid, probability)
    """
    # Format input
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
    
    return prob >= threshold, prob

# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12

Batch Processing

def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently.
    
    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference
    
    Returns:
        List of probabilities
    """
    all_probs = []
    
    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]
        
        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]
        
        inputs = tokenizer(
            texts, 
            return_tensors="pt", 
            truncation=True, 
            max_length=512, 
            padding=True
        )
        
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()
        
        all_probs.extend(probs)
    
    return all_probs
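For example (inputs are illustrative):

examples = [
    {
        "user": "Python was created in 1989.",
        "evidence": "Python was created in 1989.",
        "source": "Python is a programming language created by Guido van Rossum in 1991.",
    },
]
probs = batch_check_citations(examples)
print(probs)  # one P(valid citation) per example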

Integration with DualTrack

class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.
    
    Detects citation drift between user track and source documents.
    """
    
    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Load optimal threshold from metadata
        import json
        import os
        metadata_path = os.path.join(model_path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            self.threshold = threshold if threshold is not None else metadata.get("optimal_threshold", 0.5)
        else:
            self.threshold = threshold if threshold is not None else 0.5
    
    def detect_drift(
        self, 
        user_claims: list[str], 
        evidence_claims: list[str], 
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.
        
        Returns list of {is_valid, probability, drift_detected}.
        """
        results = []
        
        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"
            
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
            
            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })
        
        return results
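A brief usage sketch (the model path is the placeholder from above; texts are illustrative):

module = DualTrackAlignmentModule("anonymous-acl2026/dualtrack-alignment")
results = module.detect_drift(
    user_claims=["Python was created in 1989."],
    evidence_claims=["Python was created in 1989."],
    sources=["Python is a programming language created by Guido van Rossum in 1991."],
)
print(results[0])  # {"is_valid": ..., "probability": ..., "drift_detected": ...}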

Training Details

Training Data

The model was trained on a curated dataset combining multiple sources:

Source   Examples   Description
------   --------   -----------
FEVER    ~8,000     Fact verification with SUPPORTS/REFUTES labels
HAGRID   ~2,000     Attributed QA with quote-based evidence
ASQA     ~3,000     Ambiguous questions with long-form answers

Label Generation (V3 - LLM-Supervised):

  • Training labels verified by GPT-4o-mini ("Does context support claim?")
  • Evaluation uses independent NLI model (DeBERTa-MNLI)
  • This breaks circularity: the model learns from LLM judgments and is evaluated against an independent NLI model
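To illustrate the evaluation side, an off-the-shelf MNLI cross-encoder can score whether the source entails the claim. The card only identifies the evaluator as "DeBERTa-MNLI", so the checkpoint name below is an assumption:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

nli_name = "microsoft/deberta-large-mnli"  # assumed stand-in for the "DeBERTa-MNLI" evaluator
nli_tok = AutoTokenizer.from_pretrained(nli_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_name).eval()

def entailment_prob(source: str, claim: str) -> float:
    """P(entailment) that the source passage (premise) entails the claim (hypothesis)."""
    inputs = nli_tok(source, claim, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        probs = torch.softmax(nli_model(**inputs).logits, dim=-1)[0]
    ent_idx = nli_model.config.label2id.get("ENTAILMENT", 2)  # look up the entailment class
    return probs[ent_idx].item()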

Data Augmentation:

  • Negative perturbations: date_change, number_change, entity_swap, false_detail, negation, topic_drift
  • Positive perturbations: paraphrase, synonym_swap, formal_informal register changes
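The augmentation code is not part of this card; as a rough illustration, a date_change negative perturbation can be as simple as shifting a four-digit year:

import random
import re

def date_change(claim: str) -> str:
    """Illustrative negative perturbation: shift a four-digit year so the claim no longer matches the source."""
    match = re.search(r"\b(1[89]\d{2}|20\d{2})\b", claim)
    if match is None:
        return claim  # nothing to perturb
    shifted = int(match.group()) + random.choice([-3, -2, -1, 1, 2, 3])
    return claim[:match.start()] + str(shifted) + claim[match.end():]

# date_change("Python was created in 1991.") -> e.g. "Python was created in 1993."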

Training Procedure

Hyperparameter            Value
-----------------------   -------------------------
Base model                microsoft/deberta-v3-base
Max sequence length       512
Batch size                8
Gradient accumulation     2
Effective batch size      16
Learning rate             2e-5
Warmup ratio              0.1
Weight decay              0.01
Epochs                    5
Early stopping patience   3
FP16 training             Yes
Optimizer                 AdamW
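Expressed as a HuggingFace Trainer configuration, the table corresponds roughly to the sketch below (argument names differ slightly across transformers versions; the tokenized train/validation datasets and a compute_metrics helper returning "f1" are assumed):

from transformers import (AutoModelForSequenceClassification, EarlyStoppingCallback,
                          Trainer, TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2
)

args = TrainingArguments(
    output_dir="dualtrack-alignment",
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,    # effective batch size 16
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,                # AdamW is the Trainer default optimizer
    fp16=True,
    eval_strategy="epoch",            # "evaluation_strategy" in older transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,      # needed for early stopping
    metric_for_best_model="f1",       # assumes compute_metrics reports "f1"
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,      # assumed: pre-tokenized claim/evidence/source examples
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,  # assumed helper returning accuracy/precision/recall/f1
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()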

Training Infrastructure:

  • Single GPU (NVIDIA T4/V100)
  • Training time: ~2-3 hours
  • Framework: HuggingFace Transformers + PyTorch

Evaluation

Validation Set Performance (15% held-out, stratified):

Metric      Score
---------   -----
Accuracy    0.87
Precision   0.88
Recall      0.90
F1          0.89
ROC-AUC     0.94

Optimal Threshold: 0.50 (determined via F1 maximization on validation set)
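The threshold selection can be reproduced on any labeled validation split with a simple sweep; a sketch using scikit-learn (val_probs would come from batch_check_citations above, val_labels are the gold 0/1 labels):

import numpy as np
from sklearn.metrics import f1_score

def best_f1_threshold(val_probs: list[float], val_labels: list[int]) -> tuple[float, float]:
    """Sweep candidate thresholds and return (best_threshold, best_f1)."""
    probs = np.asarray(val_probs)
    labels = np.asarray(val_labels)
    thresholds = np.linspace(0.05, 0.95, 91)
    f1s = [f1_score(labels, (probs >= t).astype(int)) for t in thresholds]
    best = int(np.argmax(f1s))
    return float(thresholds[best]), float(f1s[best])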

Performance by Perturbation Type:

Type          Accuracy   Notes
-----------   --------   ---------------------------
original      0.91       Clean examples
paraphrase    0.88       Meaning-preserving rewrites
entity_swap   0.94       Wrong person/place/org
date_change   0.92       Incorrect dates
negation      0.89       Reversed claims
topic_drift   0.85       Subtle semantic shifts

Limitations

  1. English only: Trained on English source passages. Cross-lingual application requires translation or multilingual encoder.

  2. RAG-specific: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.

  3. Passage length: Max 512 tokens. Long documents require chunking or summarization (see the chunking sketch after this list).

  4. Threshold sensitivity: Default threshold (0.5) may need tuning for specific applications. High-precision applications should use higher thresholds.

  5. Training data bias: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
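For limitation 3, one workaround is to score overlapping chunks of a long source and keep the best-supported chunk; a sketch (chunk sizes and the max-over-chunks aggregation are illustrative design choices, and check_citation is the function from Basic Usage):

def check_citation_long_source(user_claim: str, evidence: str, source: str,
                               chunk_words: int = 200, overlap: int = 50) -> tuple[bool, float]:
    """Score overlapping word-level chunks of a long source and keep the highest probability."""
    words = source.split()
    step = max(chunk_words - overlap, 1)
    chunks = [" ".join(words[i:i + chunk_words]) for i in range(0, len(words), step)] or [source]
    probs = [check_citation(user_claim, evidence, chunk)[1] for chunk in chunks]
    best = max(probs)
    return best >= 0.5, best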

Ethical Considerations

Intended Benefits

  • Improved reliability of AI-generated citations
  • Reduced misinformation from RAG hallucinations
  • Better transparency in AI-assisted research

Potential Risks

  • Over-reliance on automated verification (human review still recommended for high-stakes applications)
  • False negatives may incorrectly flag valid citations
  • False positives may miss genuine attribution errors

Recommendations

  • Use as one signal among many, not sole arbiter
  • Monitor performance on domain-specific data
  • Combine with human review for critical applications

This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.