convexray committed
Commit
db7bdf2
·
verified ·
1 Parent(s): ee26f70

Create README.md

Files changed (1)
  1. README.md +355 -0
README.md CHANGED
@@ -0,0 +1,355 @@
---
language:
- en
license: mit
library_name: transformers
tags:
- citation-verification
- retrieval-augmented-generation
- rag
- cross-lingual
- deberta
- cross-encoder
- nli
- attribution
pipeline_tag: text-classification
datasets:
- fever
- din0s/asqa
- miracl/hagrid
metrics:
- f1
- precision
- recall
- accuracy
- roc_auc
base_model: microsoft/deberta-v3-base
model-index:
- name: dualtrack-alignment-module
  results:
  - task:
      type: text-classification
      name: Citation Verification
    metrics:
    - type: f1
      value: 0.89
      name: F1 Score
    - type: accuracy
      value: 0.87
      name: Accuracy
    - type: roc_auc
      value: 0.94
      name: ROC-AUC
---

# DualTrack Alignment Module

> **Anonymous submission to ACL 2026**

A cross-encoder model for detecting **citation drift** in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (the source supports the claim).

## Model Description

This model addresses a critical reliability problem in RAG systems: **citation drift**, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings, where the answer language differs from the source-document language.

### Architecture

```
Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"

DeBERTa-v3-base (184M parameters)

[CLS] embedding (768-dim)

Linear(768, 2) → Softmax

Output: P(valid citation)
```

### Why Cross-Encoder?

Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components **together**, enabling:
- Cross-attention between claim and source
- Detection of subtle semantic mismatches
- Better handling of paraphrases vs. factual errors
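
To make the contrast concrete, below is a minimal, illustrative sketch (not part of this repository) of the embedding-based alternative: a bi-encoder reduces claim and source to two independent vectors, so a single wrong token such as a date can be drowned out by overall topical similarity. The `microsoft/deberta-v3-base` backbone and mean pooling here are stand-ins chosen for illustration, not the released model.

```python
# Illustrative bi-encoder baseline (a stand-in, NOT the released model):
# claim and source are encoded separately and never attend to each other.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

enc_name = "microsoft/deberta-v3-base"  # stand-in backbone for illustration
bi_tok = AutoTokenizer.from_pretrained(enc_name)
bi_enc = AutoModel.from_pretrained(enc_name).eval()

def embed(text: str) -> torch.Tensor:
    """Mean-pooled embedding of a single text."""
    inputs = bi_tok(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        hidden = bi_enc(**inputs).last_hidden_state       # (1, seq_len, 768)
    mask = inputs["attention_mask"].unsqueeze(-1)         # (1, seq_len, 1)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # (1, 768)

claim = "Python was created in 1989."
source = "Python is a programming language created by Guido van Rossum in 1991."

# One cosine similarity from two separate encodings: the wrong year barely
# moves this score, because both texts are still about the same topic.
bi_score = F.cosine_similarity(embed(claim), embed(source)).item()

# The cross-encoder instead scores a single joint sequence (see "Basic Usage"
# below), so attention can align "1989" in the claim with "1991" in the source.
joint_input = f"User claim: {claim}\n\nEvidence: {claim}\n\nSource passage: {source}"
```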

## Intended Use

### Primary Use Cases

1. **Post-hoc citation verification**: Validate citations in RAG outputs before serving to users (see the verification-gate sketch below)
2. **Citation drift detection**: Identify claims that have semantically drifted from their sources
3. **Training signal**: Provide rewards for citation-aware generation
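
For use case 1, the gate can be as simple as running the verifier over every claim/citation pair in an answer and withholding anything below the threshold. The sketch below is a hedged illustration: the `CitedClaim` structure, the passage store, and the `gate_answer` helper are assumptions made for this example, and `check_citation` is the helper defined in *Basic Usage* below.

```python
# Hedged sketch of use case 1: gate a RAG answer on citation validity before
# serving it. The answer representation (claims paired with cited passage ids)
# is an assumption for illustration, not an API of this repository.
from dataclasses import dataclass

@dataclass
class CitedClaim:
    text: str         # sentence shown to the user
    passage_id: str   # id of the retrieved passage this sentence cites

def gate_answer(claims: list[CitedClaim], passages: dict[str, str],
                threshold: float = 0.5) -> list[dict]:
    """Score every (claim, cited passage) pair and flag likely drift."""
    report = []
    for claim in claims:
        source = passages[claim.passage_id]
        # check_citation is defined in "Basic Usage" below; evidence == claim here.
        is_valid, prob = check_citation(claim.text, claim.text, source, threshold)
        report.append({
            "claim": claim.text,
            "passage_id": claim.passage_id,
            "probability": prob,
            "action": "serve" if is_valid else "flag_for_review",
        })
    return report
```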

### Out of Scope

- General NLI/entailment (the model is specialized for RAG citation patterns)
- Fact-checking against world knowledge (the model always requires a source passage)
- Non-English source documents (trained on English sources only)

## How to Use

### Installation

```bash
pip install transformers torch
```

### Basic Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.

    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)

    Returns:
        (is_valid, probability)
    """
    # Format input
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

    return prob >= threshold, prob

# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12
```

### Batch Processing

```python
def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently.

    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference

    Returns:
        List of probabilities
    """
    all_probs = []

    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]

        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]

        inputs = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()

        all_probs.extend(probs)

    return all_probs
```

### Integration with DualTrack

```python
import json
import os


class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.

    Detects citation drift between user track and source documents.
    """

    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load optimal threshold from metadata, falling back to 0.5
        metadata_path = os.path.join(model_path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path) as f:
                metadata = json.load(f)
            default_threshold = metadata.get("optimal_threshold", 0.5)
        else:
            default_threshold = 0.5
        self.threshold = threshold if threshold is not None else default_threshold

    def detect_drift(
        self,
        user_claims: list[str],
        evidence_claims: list[str],
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.

        Returns list of {is_valid, probability, drift_detected}.
        """
        results = []

        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"

            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()

            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })

        return results
```

## Training Details

### Training Data

The model was trained on a curated dataset combining multiple sources:

| Source | Examples | Description |
|--------|----------|-------------|
| FEVER | ~8,000 | Fact verification with SUPPORTS/REFUTES labels |
| HAGRID | ~2,000 | Attributed QA with quote-based evidence |
| ASQA | ~3,000 | Ambiguous questions with long-form answers |

**Label Generation (V3 - LLM-Supervised)**:
- Training labels verified by GPT-4o-mini ("Does the context support the claim?")
- Evaluation uses an independent NLI model (DeBERTa-MNLI)
- This breaks circularity: the model learns from LLM judgments and is evaluated against NLI labels
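
The labeling pipeline itself is not part of this model card, but the sketch below shows one way the GPT-4o-mini check could look. Only the underlying question and the judge model come from the description above; the prompt wording, the use of the `openai` Python client, and the YES/NO parsing are illustrative assumptions.

```python
# Hedged sketch of the LLM-supervised labeling step. Prompt wording, client
# usage, and answer parsing are assumptions; only the question ("Does the
# context support the claim?") and the judge model are taken from the card.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def llm_label(claim: str, context: str) -> int:
    """Return 1 if the judge says the context supports the claim, else 0."""
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        messages=[{
            "role": "user",
            "content": (
                "Does the context support the claim? Answer YES or NO.\n\n"
                f"Claim: {claim}\n\nContext: {context}"
            ),
        }],
    )
    answer = response.choices[0].message.content.strip().upper()
    return 1 if answer.startswith("YES") else 0
```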

**Data Augmentation**:
- **Negative perturbations**: date_change, number_change, entity_swap, false_detail, negation, topic_drift
- **Positive perturbations**: paraphrase, synonym_swap, formal_informal register changes
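
The augmentation code is likewise not included here; the sketch below illustrates the idea behind one negative perturbation (`date_change`) and one positive perturbation (`synonym_swap`) applied to a supported claim. The regex and the tiny synonym table are simplifications for illustration.

```python
# Hedged sketch of the augmentation idea: perturb a supported claim to create
# a labeled negative (date_change) or a label-preserving positive (synonym_swap).
import random
import re

def date_change(claim: str) -> str:
    """Negative perturbation: shift a 4-digit year so the claim contradicts the source."""
    return re.sub(
        r"\b(1[0-9]{3}|20[0-9]{2})\b",
        lambda m: str(int(m.group(0)) + random.choice([-3, -2, 2, 3])),
        claim,
        count=1,
    )

def synonym_swap(claim: str) -> str:
    """Positive perturbation: swap in a meaning-preserving synonym."""
    synonyms = {"created": "developed", "famous": "well-known", "large": "big"}
    for word, replacement in synonyms.items():
        if word in claim:
            return claim.replace(word, replacement, 1)
    return claim

claim = "Python is a programming language created by Guido van Rossum in 1991."
negative_example = (date_change(claim), 0)   # label 0: no longer supported
positive_example = (synonym_swap(claim), 1)  # label 1: still supported
```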

### Training Procedure

| Hyperparameter | Value |
|----------------|-------|
| Base model | `microsoft/deberta-v3-base` |
| Max sequence length | 512 |
| Batch size | 8 |
| Gradient accumulation | 2 |
| Effective batch size | 16 |
| Learning rate | 2e-5 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Epochs | 5 |
| Early stopping patience | 3 |
| FP16 training | Yes |
| Optimizer | AdamW |

**Training Infrastructure**:
- Single GPU (NVIDIA T4/V100)
- Training time: ~2-3 hours
- Framework: HuggingFace Transformers + PyTorch
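
The training script is not released with this card, but the hyperparameters above map directly onto Hugging Face `TrainingArguments`. The sketch below is a minimal reconstruction under that assumption; `output_dir` is a placeholder, and the 512-token limit is applied at tokenization time rather than here.

```python
# Hedged sketch: the hyperparameter table expressed as TrainingArguments.
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="dualtrack-alignment",   # placeholder
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,      # effective batch size 16
    num_train_epochs=5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    optim="adamw_torch",                # AdamW
    eval_strategy="epoch",              # "evaluation_strategy" in older transformers
    save_strategy="epoch",
    load_best_model_at_end=True,        # required for early stopping
    metric_for_best_model="f1",
)

# Passed to transformers.Trainer together with the model, the tokenized
# datasets, and a compute_metrics function that reports F1:
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
```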

### Evaluation

**Validation Set Performance** (15% held-out, stratified):

| Metric | Score |
|--------|-------|
| Accuracy | 0.87 |
| Precision | 0.88 |
| Recall | 0.90 |
| F1 | 0.89 |
| ROC-AUC | 0.94 |

**Optimal Threshold**: 0.50 (determined via F1 maximization on the validation set)
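
If you need to recompute this threshold for your own data, a common recipe is to sweep the precision-recall curve on a labeled validation set and keep the F1-maximizing cut-off. The sketch below assumes you already have validation labels and probabilities (for example from `batch_check_citations` above); the scikit-learn-based selection is illustrative, not the card's original tuning code.

```python
# Hedged sketch: pick the probability threshold that maximizes F1 on a labeled
# validation set. val_labels / val_probs come from your own data.
import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(val_labels: list[int], val_probs: list[float]) -> float:
    precision, recall, thresholds = precision_recall_curve(val_labels, val_probs)
    # precision/recall have one more entry than thresholds; drop the last point.
    f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
    return float(thresholds[int(np.argmax(f1))])

# Toy example (replace with real validation outputs):
labels = [1, 0, 1, 1, 0, 0, 1, 0]
probs = [0.91, 0.40, 0.77, 0.62, 0.55, 0.08, 0.83, 0.30]
print(f"F1-optimal threshold: {best_f1_threshold(labels, probs):.2f}")
```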

**Performance by Perturbation Type**:

| Type | Accuracy | Notes |
|------|----------|-------|
| original | 0.91 | Clean examples |
| paraphrase | 0.88 | Meaning-preserving rewrites |
| entity_swap | 0.94 | Wrong person/place/org |
| date_change | 0.92 | Incorrect dates |
| negation | 0.89 | Reversed claims |
| topic_drift | 0.85 | Subtle semantic shifts |

## Limitations

1. **English only**: Trained on English source passages. Cross-lingual application requires translation or a multilingual encoder.

2. **RAG-specific**: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.

3. **Passage length**: Maximum input length is 512 tokens. Long documents require chunking or summarization (see the sketch after this list).

4. **Threshold sensitivity**: The default threshold (0.5) may need tuning for specific applications. High-precision applications should use a higher threshold.

5. **Training data bias**: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
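
For limitation 3, one common workaround is to split a long source into overlapping chunks, score the claim against each chunk, and keep the maximum probability. The sketch below is an illustration of that idea, not part of the released code: the word-based splitting is a simplification, and `check_citation` is the helper from *Basic Usage* above.

```python
# Hedged workaround for limitation 3: verify a claim against a long document by
# scoring overlapping chunks and taking the maximum probability.
def check_citation_long(user_claim: str, evidence: str, document: str,
                        chunk_words: int = 250, overlap_words: int = 50,
                        threshold: float = 0.5) -> tuple[bool, float]:
    words = document.split()
    step = max(chunk_words - overlap_words, 1)
    chunks = [" ".join(words[i:i + chunk_words])
              for i in range(0, max(len(words), 1), step)]
    best_prob = 0.0
    for chunk in chunks:
        _, prob = check_citation(user_claim, evidence, chunk, threshold)
        best_prob = max(best_prob, prob)
    return best_prob >= threshold, best_prob
```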

## Ethical Considerations

### Intended Benefits
- Improved reliability of AI-generated citations
- Reduced misinformation from RAG hallucinations
- Better transparency in AI-assisted research

### Potential Risks
- Over-reliance on automated verification (human review is still recommended for high-stakes applications)
- False negatives may incorrectly flag valid citations
- False positives may miss genuine attribution errors

### Recommendations
- Use as one signal among many, not as the sole arbiter
- Monitor performance on domain-specific data
- Combine with human review for critical applications

*This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.*