import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
    AutoConfig,
    AutoModel,
    PreTrainedModel,
    PretrainedConfig,
)


class RQAModelConfig(PretrainedConfig):
    model_type = "rqa_v2_2"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        encoder_config: Optional[dict] = None,
        error_types: Optional[List[str]] = None,
        schema_version: str = "rqa.v2.2",
        has_issue_projection_dim: int = 256,
        hidden_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        hidden_dropout: float = 0.25,
        errors_dropout: float = 0.3,
        temperature_has_issue: float = 1.0,
        temperature_is_hidden: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        threshold_has_issue: float = 0.5,
        threshold_is_hidden: float = 0.5,
        threshold_error: float = 0.5,
        threshold_errors: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_model_name = base_model_name
        self.encoder_config = encoder_config
        self.error_types = error_types or [
            "false_causality",
            "unsupported_claim",
            "overgeneralization",
            "missing_premise",
            "contradiction",
            "circular_reasoning",
        ]
        self.num_error_types = len(self.error_types)
        self.schema_version = schema_version

        # Per-head projection sizes.
        self.has_issue_projection_dim = has_issue_projection_dim
        self.hidden_projection_dim = hidden_projection_dim
        self.errors_projection_dim = errors_projection_dim

        # Per-head dropout rates.
        self.has_issue_dropout = has_issue_dropout
        self.hidden_dropout = hidden_dropout
        self.errors_dropout = errors_dropout

        # Calibration temperatures: one per binary head, one per error type.
        self.temperature_has_issue = float(temperature_has_issue)
        self.temperature_is_hidden = float(temperature_is_hidden)
        self.temperature_errors = (
            temperature_errors
            if temperature_errors is not None
            else [1.0] * self.num_error_types
        )

        # Decision thresholds; per-error-type thresholds default to the shared value.
        self.threshold_has_issue = float(threshold_has_issue)
        self.threshold_is_hidden = float(threshold_is_hidden)
        self.threshold_error = float(threshold_error)
        self.threshold_errors = (
            threshold_errors
            if threshold_errors is not None
            else [float(threshold_error)] * self.num_error_types
        )

        # Compatibility flags for the "eager" experts implementation.
        try:
            self._experts_implementation = "eager"
            self._experts_implementation_internal = "eager"
        except Exception:
            pass


class MeanPooling(nn.Module):
    """Mean-pool token embeddings, ignoring padding positions via the attention mask."""

    def forward(self, last_hidden_state, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(last_hidden_state * mask, dim=1)
        denom = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / denom


class RQAModelHF(PreTrainedModel):
    config_class = RQAModelConfig
    _supports_grouped_mm = False

    def __init__(self, config: RQAModelConfig):
        super().__init__(config)

        # Compatibility flags for the "eager" experts implementation.
        try:
            config._experts_implementation = "eager"
            config._experts_implementation_internal = "eager"
        except Exception:
            pass

        # Shared encoder backbone.
        self.encoder = AutoModel.from_pretrained(config.base_model_name)
        hidden_size = self.encoder.config.hidden_size

        self.pooler = MeanPooling()

        # Per-task projection blocks on top of the pooled representation.
        self.has_issue_projection = nn.Sequential(
            nn.Linear(hidden_size, config.has_issue_projection_dim),
            nn.LayerNorm(config.has_issue_projection_dim),
            nn.GELU(),
            nn.Dropout(config.has_issue_dropout),
        )

        self.hidden_projection = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_projection_dim),
            nn.LayerNorm(config.hidden_projection_dim),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
        )

        self.errors_projection = nn.Sequential(
            nn.Linear(hidden_size, config.errors_projection_dim),
            nn.LayerNorm(config.errors_projection_dim),
            nn.GELU(),
            nn.Dropout(config.errors_dropout),
        )

        # Classification heads: two binary heads and one multi-label errors head.
        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.is_hidden_head = nn.Linear(config.hidden_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim,
            config.num_error_types,
        )

        # Learnable per-task log-variances (e.g. for uncertainty-weighted multi-task loss).
        self.log_var_has_issue = nn.Parameter(torch.zeros(1))
        self.log_var_is_hidden = nn.Parameter(torch.zeros(1))
        self.log_var_errors = nn.Parameter(torch.zeros(1))

        self._init_custom_weights()

    def _init_custom_weights(self):
        # Xavier-initialize the custom projection and head layers.
        for module in [
            self.has_issue_projection[0],
            self.hidden_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.is_hidden_head,
            self.errors_head,
        ]:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        pooled = self.pooler(outputs.last_hidden_state, attention_mask)

        has_issue_logits = self.has_issue_head(
            self.has_issue_projection(pooled)
        ).squeeze(-1)

        is_hidden_logits = self.is_hidden_head(
            self.hidden_projection(pooled)
        ).squeeze(-1)

        errors_logits = self.errors_head(
            self.errors_projection(pooled)
        )

        return {
            "has_issue_logits": has_issue_logits,
            "is_hidden_logits": is_hidden_logits,
            "errors_logits": errors_logits,
        }
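

# NOTE: illustrative sketch only, not part of the original model code. The helper
# name `apply_calibration` is ours. It shows one plausible way to connect the
# calibration fields stored in RQAModelConfig (temperatures and thresholds) to the
# raw logits returned by RQAModelHF.forward; the actual post-processing used in
# training or serving may differ.
@torch.no_grad()
def apply_calibration(outputs: dict, config: RQAModelConfig) -> dict:
    """Temperature-scale the logits and binarize them with the configured thresholds."""
    device = outputs["errors_logits"].device
    dtype = outputs["errors_logits"].dtype

    has_issue_prob = torch.sigmoid(
        outputs["has_issue_logits"] / config.temperature_has_issue
    )
    is_hidden_prob = torch.sigmoid(
        outputs["is_hidden_logits"] / config.temperature_is_hidden
    )
    error_temps = torch.tensor(config.temperature_errors, device=device, dtype=dtype)
    errors_prob = torch.sigmoid(outputs["errors_logits"] / error_temps)

    error_thresholds = torch.tensor(config.threshold_errors, device=device, dtype=dtype)
    return {
        "has_issue": has_issue_prob >= config.threshold_has_issue,
        "is_hidden": is_hidden_prob >= config.threshold_is_hidden,
        "errors": errors_prob >= error_thresholds,
        "has_issue_prob": has_issue_prob,
        "is_hidden_prob": is_hidden_prob,
        "errors_prob": errors_prob,
    }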


AutoConfig.register("rqa_v2_2", RQAModelConfig)
AutoModel.register(RQAModelConfig, RQAModelHF)

print("✅ RQA-R2 registered in Transformers")
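

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; assumes network access to download the
    # FacebookAI/xlm-roberta-large weights from the Hugging Face Hub): build the
    # model from the default config and run a dummy batch through it.
    config = RQAModelConfig()
    model = RQAModelHF(config).eval()

    input_ids = torch.randint(0, model.encoder.config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    for name, tensor in outputs.items():
        print(name, tuple(tensor.shape))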