"""IRouterLM Model - RAG Strategy Router Model."""

from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, Qwen3Model

from .configuration_irouterlm import IRouterLMConfig


class IRouterLMModel(PreTrainedModel):
    """
    IRouterLM: Intelligent Router for RAG Strategy Selection.

    A Qwen3-0.6B based model fine-tuned to classify queries into the optimal
    RAG retrieval strategy.

    Strategies:
        0: MULTIMODAL_RERANK - Multimodal retrieval with reranking
        1: MULTIMODAL_SINGLE - Single-stage multimodal retrieval
        2: TEXT_RERANK - Text-only retrieval with reranking
        3: TEXT_SINGLE - Single-stage text retrieval
    """

    config_class = IRouterLMConfig
    _no_split_modules = ["Qwen3DecoderLayer"]

    def __init__(self, config: IRouterLMConfig):
        super().__init__(config)

        # Qwen3 backbone that encodes the query into contextual token states.
        self.transformer = Qwen3Model.from_pretrained(
            config.base_model_name,
            trust_remote_code=True,
        )

        # Classification head applied to the mean-pooled hidden states.
        self.dropout = nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def _init_weights(self, module):
        """Initialize classifier weights."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for strategy classification.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Mask of shape (batch, seq_len), 1 for real tokens.
            labels: Optional soft targets of shape (batch, num_labels), e.g.
                per-strategy rewards consumed by `_compute_loss`.

        Returns:
            Dict with "loss" (None when no labels are given) and "logits" of
            shape (batch, num_labels).
        """
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

        hidden_states = outputs.last_hidden_state

        # Mean-pool the token states, ignoring padding positions.
        if attention_mask is not None:
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            pooled = sum_hidden / sum_mask
        else:
            pooled = hidden_states.mean(dim=1)

        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {"loss": loss, "logits": logits}

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute reward-weighted cross-entropy against soft labels.

        `labels` is expected to hold a non-negative reward per strategy. The
        rewards are normalized into a target distribution (equivalent to a
        KL-divergence objective up to a constant), and each sample is weighted
        by the reward of its best strategy.
        """
        EPS = 1e-8
        reward_sum = labels.sum(dim=-1, keepdim=True)
        labels_normalized = labels / (reward_sum + EPS)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        sample_losses = -(labels_normalized * log_probs).sum(dim=-1)
        sample_weights = labels.max(dim=-1)[0]
        return (sample_losses * sample_weights).mean()
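
    # Illustration (hypothetical numbers, not from the source): for a sample
    # with rewards labels = [0.9, 0.1, 0.0, 0.0], the normalized target is
    # approximately [0.9, 0.1, 0.0, 0.0] and the sample weight is 0.9, so
    # queries whose best strategy earned a high reward contribute more to the
    # mean loss.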

    def predict(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        Predict the best RAG strategy for the given queries.
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(input_ids, attention_mask)
            probs = torch.softmax(outputs["logits"], dim=-1)
            predictions = probs.argmax(dim=-1)

        return {
            "predictions": predictions,
            "probabilities": probs,
            "strategy_names": [self.config.strategy_names[p.item()] for p in predictions],
        }
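

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model code). It assumes
# that IRouterLMConfig() provides usable defaults (base_model_name, num_labels,
# strategy_names) and that queries are tokenized with the base Qwen3 tokenizer;
# adjust this to match how the checkpoint is actually published. Run it as a
# module (python -m <package>.modeling_irouterlm) so the relative import above
# resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = IRouterLMConfig()
    model = IRouterLMModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    queries = ["What does the chart on page 3 show?", "Summarize the abstract."]
    batch = tokenizer(queries, return_tensors="pt", padding=True, truncation=True)

    result = model.predict(batch["input_ids"], batch["attention_mask"])
    for query, name in zip(queries, result["strategy_names"]):
        print(f"{query!r} -> {name}")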