"""IRouterLM Model - RAG Strategy Router Model."""
from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel, Qwen3Model

from .configuration_irouterlm import IRouterLMConfig
class IRouterLMModel(PreTrainedModel):
"""
IRouterLM: Intelligent Router for RAG Strategy Selection.
    A Qwen3-0.6B-based model fine-tuned to classify queries into the
    optimal RAG retrieval strategy.

    Strategies:
        0: MULTIMODAL_RERANK - Multimodal retrieval with reranking
        1: MULTIMODAL_SINGLE - Single-stage multimodal retrieval
        2: TEXT_RERANK - Text-only retrieval with reranking
        3: TEXT_SINGLE - Single-stage text retrieval
"""
config_class = IRouterLMConfig
_no_split_modules = ["Qwen3DecoderLayer"]
def __init__(self, config: IRouterLMConfig):
super().__init__(config)
# Load base Qwen3 model
self.transformer = Qwen3Model.from_pretrained(
config.base_model_name,
trust_remote_code=True,
)
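        # Note: the backbone weights are fetched from the hub at construction time,
        # so loading this class via `from_pretrained` first builds the base Qwen3
        # model and then loads the fine-tuned checkpoint weights on top.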
# Classification head
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights
self.post_init()
def _init_weights(self, module):
"""Initialize classifier weights."""
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for strategy classification.

        Args:
            input_ids: Token ids, shape (batch_size, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, shape (batch_size, seq_len).
            labels: Optional soft labels (per-strategy rewards), shape (batch_size, num_labels).

        Returns:
            Dict with "logits" of shape (batch_size, num_labels) and "loss"
            (None when no labels are provided).
        """
# Get base model outputs
outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
# Mean pooling over sequence dimension
hidden_states = outputs.last_hidden_state
if attention_mask is not None:
mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
pooled = sum_hidden / sum_mask
else:
pooled = hidden_states.mean(dim=1)
# Classification
pooled = self.dropout(pooled)
logits = self.classifier(pooled)
loss = None
if labels is not None:
loss = self._compute_loss(logits, labels)
return {"loss": loss, "logits": logits}
    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Reward-weighted soft-label loss: cross-entropy between log_softmax(logits)
        and the reward-normalized label distribution (equivalent to a KL divergence
        up to a term that does not depend on the logits), with each sample weighted
        by its maximum reward.
        """
        EPS = 1e-8
        # Normalize per-strategy rewards into a target probability distribution.
        reward_sum = labels.sum(dim=-1, keepdim=True)
        labels_normalized = labels / (reward_sum + EPS)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        sample_losses = -(labels_normalized * log_probs).sum(dim=-1)
        # Weight each sample by its best achievable reward, so low-reward queries count less.
        sample_weights = labels.max(dim=-1)[0]
        return (sample_losses * sample_weights).mean()
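    # Illustrative note (hypothetical numbers, not from the original code): for soft
    # labels [0.8, 0.2, 0.0, 0.0] the normalized target is the same vector (it already
    # sums to 1) and the sample weight is max(labels) = 0.8, so this query contributes
    # 0.8x its cross-entropy term to the batch mean.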
    def predict(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        Predict the best RAG strategy for the given queries.

        Returns a dict with the predicted class indices, per-class probabilities,
        and the corresponding strategy names from the config.
        """
        self.eval()
        with torch.no_grad():
            outputs = self(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.softmax(outputs["logits"], dim=-1)
            predictions = probs.argmax(dim=-1)
return {
"predictions": predictions,
"probabilities": probs,
"strategy_names": [self.config.strategy_names[p.item()] for p in predictions],
}
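
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original checkpoint code).
# It assumes the model is hosted under the repo id "ananoymous/IRouterLM", that
# the repo's config registers this class via `auto_map`, and that a compatible
# tokenizer is stored in the same repo; adjust these names to your setup.
#
#     from transformers import AutoModel, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("ananoymous/IRouterLM")
#     model = AutoModel.from_pretrained("ananoymous/IRouterLM", trust_remote_code=True)
#
#     queries = ["What does the revenue chart on page 3 show?"]
#     batch = tokenizer(queries, return_tensors="pt", padding=True, truncation=True)
#     result = model.predict(batch["input_ids"], batch["attention_mask"])
#     print(result["strategy_names"])   # e.g. ["MULTIMODAL_RERANK"]
# ---------------------------------------------------------------------------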