# RQA-R2: modeling_rqa.py
import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
AutoConfig,
AutoModel,
PreTrainedModel,
PretrainedConfig,
)
class RQAModelConfig(PretrainedConfig):
model_type = "rqa_v2_2"
def __init__(
self,
base_model_name: str = "FacebookAI/xlm-roberta-large",
encoder_config: Optional[dict] = None,
error_types: Optional[List[str]] = None,
schema_version: str = "rqa.v2.2",
has_issue_projection_dim: int = 256,
hidden_projection_dim: int = 256,
errors_projection_dim: int = 512,
has_issue_dropout: float = 0.25,
hidden_dropout: float = 0.25,
errors_dropout: float = 0.3,
temperature_has_issue: float = 1.0,
temperature_is_hidden: float = 1.0,
temperature_errors: Optional[List[float]] = None,
threshold_has_issue: float = 0.5,
threshold_is_hidden: float = 0.5,
threshold_error: float = 0.5,
threshold_errors: Optional[List[float]] = None,
**kwargs
):
super().__init__(**kwargs)
self.base_model_name = base_model_name
self.encoder_config = encoder_config
self.error_types = error_types or [
"false_causality",
"unsupported_claim",
"overgeneralization",
"missing_premise",
"contradiction",
"circular_reasoning",
]
self.num_error_types = len(self.error_types)
self.schema_version = schema_version
self.has_issue_projection_dim = has_issue_projection_dim
self.hidden_projection_dim = hidden_projection_dim
self.errors_projection_dim = errors_projection_dim
self.has_issue_dropout = has_issue_dropout
self.hidden_dropout = hidden_dropout
self.errors_dropout = errors_dropout
self.temperature_has_issue = float(temperature_has_issue)
self.temperature_is_hidden = float(temperature_is_hidden)
self.temperature_errors = (
temperature_errors
if temperature_errors is not None
else [1.0] * self.num_error_types
)
self.threshold_has_issue = float(threshold_has_issue)
self.threshold_is_hidden = float(threshold_is_hidden)
self.threshold_error = float(threshold_error)
self.threshold_errors = (
threshold_errors
if threshold_errors is not None
else [float(threshold_error)] * self.num_error_types
)
        # Compatibility hint forcing the "eager" experts implementation.
        # Setting attributes on a config cannot raise, so no exception
        # guard is needed.
        self._experts_implementation = "eager"
        self._experts_implementation_internal = "eager"
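# Construction sketch (illustrative values): the list-valued calibration
# fields are derived per error type when omitted, e.g.
#
#   cfg = RQAModelConfig(threshold_error=0.4)
#   cfg.num_error_types    # 6 (length of the default error_types list)
#   cfg.threshold_errors   # [0.4] * 6
#   cfg.temperature_errors # [1.0] * 6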
class MeanPooling(nn.Module):
    """Masked mean pooling over the sequence dimension.

    Given last_hidden_state of shape (batch, seq_len, hidden) and an
    attention_mask of shape (batch, seq_len), returns a (batch, hidden)
    average over non-padding tokens only.
    """

    def forward(self, last_hidden_state, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        # Clamp guards against division by zero for an all-padding row.
        denom = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / denom
class RQAModelHF(PreTrainedModel):
config_class = RQAModelConfig
_supports_grouped_mm = False
def __init__(self, config: RQAModelConfig):
super().__init__(config)
        # Mirror the compatibility hint onto the runtime config; plain
        # attribute assignment cannot raise, so no exception guard is needed.
        config._experts_implementation = "eager"
        config._experts_implementation_internal = "eager"
        # Rebuild the encoder from the stored config when the checkpoint saved
        # one (so from_pretrained restores the fine-tuned encoder weights
        # without re-downloading the base model); otherwise fall back to
        # loading the pretrained base encoder.
        if config.encoder_config is not None:
            self.encoder = AutoModel.from_config(
                AutoConfig.for_model(**config.encoder_config)
            )
        else:
            self.encoder = AutoModel.from_pretrained(config.base_model_name)
hidden_size = self.encoder.config.hidden_size
self.pooler = MeanPooling()
self.has_issue_projection = nn.Sequential(
nn.Linear(hidden_size, config.has_issue_projection_dim),
nn.LayerNorm(config.has_issue_projection_dim),
nn.GELU(),
nn.Dropout(config.has_issue_dropout),
)
self.hidden_projection = nn.Sequential(
nn.Linear(hidden_size, config.hidden_projection_dim),
nn.LayerNorm(config.hidden_projection_dim),
nn.GELU(),
nn.Dropout(config.hidden_dropout),
)
self.errors_projection = nn.Sequential(
nn.Linear(hidden_size, config.errors_projection_dim),
nn.LayerNorm(config.errors_projection_dim),
nn.GELU(),
nn.Dropout(config.errors_dropout),
)
self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
self.is_hidden_head = nn.Linear(config.hidden_projection_dim, 1)
self.errors_head = nn.Linear(
config.errors_projection_dim,
config.num_error_types,
)
        # Learned per-task log-variances, intended for uncertainty-weighted
        # multi-task training; not used in forward().
        self.log_var_has_issue = nn.Parameter(torch.zeros(1))
        self.log_var_is_hidden = nn.Parameter(torch.zeros(1))
        self.log_var_errors = nn.Parameter(torch.zeros(1))
self._init_custom_weights()
def _init_custom_weights(self):
for module in [
self.has_issue_projection[0],
self.hidden_projection[0],
self.errors_projection[0],
self.has_issue_head,
self.is_hidden_head,
self.errors_head,
]:
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Extra tokenizer fields (e.g. token_type_ids) are accepted but ignored.
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
pooled = self.pooler(outputs.last_hidden_state, attention_mask)
has_issue_logits = self.has_issue_head(
self.has_issue_projection(pooled)
).squeeze(-1)
is_hidden_logits = self.is_hidden_head(
self.hidden_projection(pooled)
).squeeze(-1)
errors_logits = self.errors_head(
self.errors_projection(pooled)
)
return {
"has_issue_logits": has_issue_logits,
"is_hidden_logits": is_hidden_logits,
"errors_logits": errors_logits,
}
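# The two helpers below are sketches, not part of the original module.
# Decoding: the config's temperature_* and threshold_* fields suggest
# temperature-scaled sigmoids compared against per-head thresholds; the exact
# calibration scheme is an assumption here.
@torch.no_grad()
def decode_rqa_outputs(outputs: dict, config: RQAModelConfig) -> dict:
    """Map raw head logits to boolean predictions via config calibration."""
    has_issue = (
        torch.sigmoid(outputs["has_issue_logits"] / config.temperature_has_issue)
        > config.threshold_has_issue
    )
    is_hidden = (
        torch.sigmoid(outputs["is_hidden_logits"] / config.temperature_is_hidden)
        > config.threshold_is_hidden
    )
    device = outputs["errors_logits"].device
    temps = torch.tensor(config.temperature_errors, device=device)
    thresholds = torch.tensor(config.threshold_errors, device=device)
    errors = torch.sigmoid(outputs["errors_logits"] / temps) > thresholds
    return {"has_issue": has_issue, "is_hidden": is_hidden, "errors": errors}


# Training-loss sketch, assuming the log_var_* parameters implement
# homoscedastic uncertainty weighting (Kendall et al., 2018); the actual
# training objective is not shown in this file.
def multitask_loss_sketch(model: RQAModelHF, outputs: dict, labels: dict):
    bce = nn.BCEWithLogitsLoss()
    loss_has = bce(outputs["has_issue_logits"], labels["has_issue"].float())
    loss_hid = bce(outputs["is_hidden_logits"], labels["is_hidden"].float())
    loss_err = bce(outputs["errors_logits"], labels["errors"].float())
    total = (
        torch.exp(-model.log_var_has_issue) * loss_has + model.log_var_has_issue
        + torch.exp(-model.log_var_is_hidden) * loss_hid + model.log_var_is_hidden
        + torch.exp(-model.log_var_errors) * loss_err + model.log_var_errors
    )
    return total.squeeze()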
AutoConfig.register("rqa_v2_2", RQAModelConfig)
AutoModel.register(RQAModelConfig, RQAModelHF)
print("✅ RQA-R2 зарегистрирован в Transformers")