skatzR commited on
Commit
5ffc4dc
·
verified ·
1 Parent(s): 23b4804

Upload 3 files

Browse files
Files changed (3) hide show
  1. __init__.py +15 -0
  2. inference.py +356 -0
  3. modeling_rqa.py +214 -0
__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Package initializer: exposes the RQA model classes and registers the
# custom "rqa_v2_2" model type with the Hugging Face Auto* factories so
# AutoConfig/AutoModel can resolve checkpoints of this type.
from transformers import AutoConfig, AutoModel

from .modeling_rqa import RQAModelConfig, RQAModelHF

__all__ = ["RQAModelConfig", "RQAModelHF"]

# register() raises ValueError when the type is already registered (e.g.
# modeling_rqa already registered it at import time, or this package is
# imported twice) — that is harmless, so it is deliberately ignored.
try:
    AutoConfig.register("rqa_v2_2", RQAModelConfig)
except ValueError:
    pass

try:
    AutoModel.register(RQAModelConfig, RQAModelHF)
except ValueError:
    pass
inference.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer
6
+
7
+ try:
8
+ from huggingface_hub import hf_hub_download
9
+ except Exception:
10
+ hf_hub_download = None
11
+
12
+ try:
13
+ from .modeling_rqa import RQAModelHF
14
+ except ImportError:
15
+ from modeling_rqa import RQAModelHF
16
+
17
+
18
# Russian display names for each machine-readable error type; used by
# RQAInferenceHF.pretty_print when use_russian_names=True. Keys must match
# the error_types list stored in the model config.
ERROR_NAMES_RU = {
    "false_causality": "Ложная причинно-следственная связь",
    "unsupported_claim": "Неподкрепленное утверждение",
    "overgeneralization": "Чрезмерное обобщение",
    "missing_premise": "Отсутствующая предпосылка",
    "contradiction": "Противоречие",
    "circular_reasoning": "Круговое рассуждение",
}
26
+
27
+
28
+ def _resolve_calibration_path(model_path: str) -> Optional[str]:
29
+ local_path = os.path.join(model_path, "calibration_data.pth")
30
+ if os.path.exists(local_path):
31
+ return local_path
32
+
33
+ if hf_hub_download is None or os.path.isdir(model_path):
34
+ return None
35
+
36
+ try:
37
+ return hf_hub_download(
38
+ repo_id=model_path,
39
+ filename="calibration_data.pth",
40
+ )
41
+ except Exception:
42
+ return None
43
+
44
+
45
class RQAInferenceHF:
    """Inference wrapper around :class:`RQAModelHF`.

    Loads model + tokenizer from *model_path*, reads calibration defaults
    (temperatures and decision thresholds) from the model config, then
    optionally overrides them from a ``calibration_data.pth`` artifact.
    :meth:`predict` classifies a text into one of three classes:
    ``"logical"`` (no issue), ``"hidden"`` (hidden problem), or
    ``"explicit"`` (one or more detected error types).
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[torch.device] = None,
        max_length: int = 512,
        issue_uncertain_margin: float = 0.05,
        hidden_uncertain_margin: float = 0.05,
        error_uncertain_margin: float = 0.05,
    ):
        self.model_path = model_path
        # Default to CUDA when available; an explicit device argument wins.
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.max_length = int(max_length)
        # A probability landing within +/- margin of its threshold marks the
        # whole prediction status="uncertain" / review_required=True.
        self.issue_uncertain_margin = float(issue_uncertain_margin)
        self.hidden_uncertain_margin = float(hidden_uncertain_margin)
        self.error_uncertain_margin = float(error_uncertain_margin)

        self.model = RQAModelHF.from_pretrained(model_path).to(self.device).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Calibration defaults come from the model config; any of them may be
        # overridden by the calibration artifact loaded below.
        cfg = self.model.config
        self.schema_version = str(getattr(cfg, "schema_version", "unknown"))
        self.error_types = list(getattr(cfg, "error_types", []))
        self.t_issue = float(getattr(cfg, "temperature_has_issue", 1.0))
        self.t_hidden = float(getattr(cfg, "temperature_is_hidden", 1.0))
        self.t_errors = list(
            getattr(cfg, "temperature_errors", [1.0] * len(self.error_types))
        )
        self.th_issue = float(getattr(cfg, "threshold_has_issue", 0.5))
        self.th_hidden = float(getattr(cfg, "threshold_is_hidden", 0.5))
        self.th_error = float(getattr(cfg, "threshold_error", 0.5))
        self.th_errors = list(
            getattr(cfg, "threshold_errors", [self.th_error] * len(self.error_types))
        )

        calibration_path = _resolve_calibration_path(model_path)
        if calibration_path:
            # NOTE(review): torch.load unpickles arbitrary objects; consider
            # weights_only=True if the artifact format allows it — confirm
            # what calibration_data.pth actually contains.
            calibration = torch.load(calibration_path, map_location="cpu")
            calibration_error_types = calibration.get("error_types", None)
            if calibration_error_types is not None:
                # The per-error temperatures/thresholds are positional, so the
                # artifact must agree with the model's error_types ordering.
                if list(calibration_error_types) != self.error_types:
                    raise ValueError(
                        "Calibration artifact error_types mismatch with model.config.error_types."
                    )

            self.schema_version = str(
                calibration.get("schema_version", self.schema_version)
            )
            self.t_issue = float(
                calibration.get("temperature_has_issue", self.t_issue)
            )
            self.t_hidden = float(
                calibration.get("temperature_is_hidden", self.t_hidden)
            )
            self.t_errors = list(
                calibration.get("temperature_errors", self.t_errors)
            )
            self.th_issue = float(
                calibration.get("threshold_has_issue", self.th_issue)
            )
            self.th_hidden = float(
                calibration.get("threshold_is_hidden", self.th_hidden)
            )
            self.th_error = float(
                calibration.get("threshold_error", self.th_error)
            )
            self.th_errors = list(
                calibration.get("threshold_errors", self.th_errors)
            )

    def _apply_temperature(
        self,
        issue_logits: torch.Tensor,
        hidden_logits: torch.Tensor,
        errors_logits: torch.Tensor,
    ):
        """Divide each logit by its calibrated temperature (temperature scaling)."""
        calibrated_issue = issue_logits / float(self.t_issue)
        calibrated_hidden = hidden_logits / float(self.t_hidden)
        # Clone so the caller's tensor is not mutated by the in-place column updates.
        calibrated_errors = errors_logits.clone()
        for idx in range(calibrated_errors.size(1)):
            # Missing per-error temperatures default to 1.0 (no scaling).
            temperature = float(self.t_errors[idx]) if idx < len(self.t_errors) else 1.0
            calibrated_errors[:, idx] = calibrated_errors[:, idx] / temperature
        return calibrated_issue, calibrated_hidden, calibrated_errors

    @torch.no_grad()
    def predict(
        self,
        text: str,
        return_probs: bool = False,
        threshold_issue: Optional[float] = None,
        threshold_hidden: Optional[float] = None,
        threshold_error: Optional[float] = None,
        threshold_errors: Optional[List[float]] = None,
    ) -> Dict[str, Any]:
        """Classify *text*; returns a dict describing class, probabilities and thresholds.

        Decision cascade: has_issue gate first ("logical" if below threshold),
        then hidden-problem gate ("hidden" if above threshold), otherwise
        per-error-type multilabel detection ("explicit"). Early returns skip
        the later stages. When return_probs is True, raw probabilities are
        attached under "raw" / "error_probabilities".
        """
        # Per-call threshold overrides fall back to the calibrated defaults.
        issue_threshold = self.th_issue if threshold_issue is None else float(threshold_issue)
        hidden_threshold = self.th_hidden if threshold_hidden is None else float(threshold_hidden)
        error_threshold = self.th_error if threshold_error is None else float(threshold_error)
        error_thresholds = self.th_errors if threshold_errors is None else list(threshold_errors)

        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        issue_logits, hidden_logits, errors_logits = self._apply_temperature(
            outputs["has_issue_logits"],
            outputs["is_hidden_logits"],
            outputs["errors_logits"],
        )

        issue_probability = float(torch.sigmoid(issue_logits).item())
        has_issue = issue_probability >= issue_threshold

        # Result skeleton; later stages fill in hidden/errors fields.
        result: Dict[str, Any] = {
            "schema_version": self.schema_version,
            "text": text,
            "class": None,
            "status": "ok",
            "review_required": False,
            "has_logical_issue": bool(has_issue),
            "has_issue_probability": issue_probability,
            "threshold_has_issue": issue_threshold,
            "temperature_has_issue": float(self.t_issue),
            "is_hidden_problem": False,
            "hidden_probability": None,
            "threshold_is_hidden": hidden_threshold,
            "temperature_is_hidden": float(self.t_hidden),
            "errors": [],
            "num_errors": 0,
            "threshold_error": error_threshold,
            "threshold_errors": error_thresholds,
            # "calibrated" means at least one temperature differs from 1.0.
            "calibrated": (
                abs(self.t_issue - 1.0) > 1e-6
                or abs(self.t_hidden - 1.0) > 1e-6
                or any(abs(float(t) - 1.0) > 1e-6 for t in self.t_errors)
            ),
        }

        if abs(issue_probability - issue_threshold) <= self.issue_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        if not has_issue:
            # Below the issue threshold: the text is considered logically sound.
            result["class"] = "logical"
            if return_probs:
                result["raw"] = {"p_issue": issue_probability}
            return result

        hidden_probability = float(torch.sigmoid(hidden_logits).item())
        is_hidden = hidden_probability >= hidden_threshold
        result["hidden_probability"] = hidden_probability
        result["is_hidden_problem"] = bool(is_hidden)

        if abs(hidden_probability - hidden_threshold) <= self.hidden_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        if is_hidden:
            # Hidden problems carry no per-error breakdown; stop here.
            result["class"] = "hidden"
            if return_probs:
                result["raw"] = {
                    "p_issue": issue_probability,
                    "p_hidden": hidden_probability,
                }
            return result

        # Explicit problem: multilabel detection over the known error types
        # (batch size is 1, hence the [0] indexing).
        error_probabilities = torch.sigmoid(errors_logits).cpu().numpy()[0]
        detected_errors = []
        for idx, error_type in enumerate(self.error_types):
            probability = float(error_probabilities[idx])
            threshold_i = float(
                error_thresholds[idx] if idx < len(error_thresholds) else error_threshold
            )
            if abs(probability - threshold_i) <= self.error_uncertain_margin:
                result["status"] = "uncertain"
                result["review_required"] = True
            if probability >= threshold_i:
                detected_errors.append(
                    {
                        "type": error_type,
                        "probability": probability,
                        "threshold": threshold_i,
                        "temperature": float(self.t_errors[idx]) if idx < len(self.t_errors) else 1.0,
                    }
                )

        # Most confident error first.
        detected_errors.sort(key=lambda item: item["probability"], reverse=True)
        result["class"] = "explicit"
        result["errors"] = detected_errors
        result["num_errors"] = len(detected_errors)

        if return_probs:
            result["error_probabilities"] = {
                error_type: float(probability)
                for error_type, probability in zip(self.error_types, error_probabilities)
            }
            result["raw"] = {
                "p_issue": issue_probability,
                "p_hidden": hidden_probability,
            }

        return result

    def pretty_print(self, prediction: Dict[str, Any], use_russian_names: bool = True) -> None:
        """Print a human-readable summary of a :meth:`predict` result."""
        print("-" * 70)
        print(
            f"Class: {prediction['class']} | status={prediction['status']} "
            f"| review_required={prediction['review_required']}"
        )
        print(
            f"Issue: {prediction['has_logical_issue']} "
            f"({prediction['has_issue_probability'] * 100:.2f}%) "
            f"th={prediction['threshold_has_issue']:.3f}"
        )
        # hidden_probability is None when prediction stopped at the issue gate.
        if prediction["hidden_probability"] is not None:
            print(
                f"Hidden: {prediction['is_hidden_problem']} "
                f"({prediction['hidden_probability'] * 100:.2f}%) "
                f"th={prediction['threshold_is_hidden']:.3f}"
            )

        if prediction["errors"]:
            printable_errors = []
            for item in prediction["errors"]:
                # Optionally map machine names to Russian display names.
                label = (
                    ERROR_NAMES_RU.get(item["type"], item["type"])
                    if use_russian_names
                    else item["type"]
                )
                printable_errors.append((label, round(item["probability"], 3)))
            print(f"Top errors: {printable_errors}")
284
+
285
+
286
class RQAJudge:
    """Convenience facade over :class:`RQAInferenceHF`.

    Wraps a default hub model and flattens the rich prediction dict into a
    smaller, stable result shape.
    """

    def __init__(
        self,
        model_name: str = "skatzR/RQA-R2",
        device: Optional[torch.device] = None,
        max_length: int = 512,
    ):
        self.runner = RQAInferenceHF(
            model_path=model_name,
            device=device,
            max_length=max_length,
        )

    def infer(
        self,
        text: str,
        issue_threshold: Optional[float] = None,
        hidden_threshold: Optional[float] = None,
        error_threshold: Optional[float] = None,
        error_thresholds: Optional[List[float]] = None,
    ) -> Dict[str, Any]:
        """Run prediction on *text* and return a flattened result dict."""
        raw = self.runner.predict(
            text=text,
            return_probs=True,
            threshold_issue=issue_threshold,
            threshold_hidden=hidden_threshold,
            threshold_error=error_threshold,
            threshold_errors=error_thresholds,
        )
        # Errors are flattened to (type, probability) pairs.
        flat_errors = [(entry["type"], entry["probability"]) for entry in raw["errors"]]
        return {
            "text": text,
            "class": raw["class"],
            "status": raw["status"],
            "review_required": raw["review_required"],
            "has_issue": raw["has_logical_issue"],
            "issue_probability": raw["has_issue_probability"],
            "hidden_problem": raw["is_hidden_problem"],
            "hidden_probability": raw["hidden_probability"],
            "errors": flat_errors,
            "num_errors": raw["num_errors"],
            "threshold_has_issue": raw["threshold_has_issue"],
            "threshold_is_hidden": raw["threshold_is_hidden"],
            "threshold_error": raw["threshold_error"],
        }

    def pretty_print(self, result: Dict[str, Any], use_russian_names: bool = True) -> None:
        """Re-shape a flattened :meth:`infer` result and delegate to the runner's printer."""
        rehydrated_errors = [
            {"type": err_type, "probability": err_prob}
            for err_type, err_prob in result["errors"]
        ]
        converted = {
            "class": result["class"],
            "status": result["status"],
            "review_required": result["review_required"],
            "has_logical_issue": result["has_issue"],
            "has_issue_probability": result["issue_probability"],
            "threshold_has_issue": result["threshold_has_issue"],
            "is_hidden_problem": result["hidden_problem"],
            "hidden_probability": result["hidden_probability"],
            "threshold_is_hidden": result["threshold_is_hidden"],
            "errors": rehydrated_errors,
        }
        self.runner.pretty_print(converted, use_russian_names=use_russian_names)
354
+
355
+
356
+ __all__ = ["RQAInferenceHF", "RQAJudge", "ERROR_NAMES_RU"]
modeling_rqa.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoConfig, AutoModel, PreTrainedModel, PretrainedConfig
6
+
7
+
8
class RQAModelConfig(PretrainedConfig):
    """Configuration for the RQA multi-head classifier.

    Stores the encoder identity (base model name and/or its saved config
    dict), the list of error types, per-head projection/dropout sizes, and
    the calibration state (temperatures and decision thresholds) that
    inference code reads back via getattr.
    """

    model_type = "rqa_v2_2"

    def __init__(
        self,
        base_model_name: str = "FacebookAI/xlm-roberta-large",
        encoder_config: Optional[Dict[str, Any]] = None,
        error_types: Optional[List[str]] = None,
        schema_version: str = "rqa.v2.2",
        has_issue_projection_dim: int = 256,
        hidden_projection_dim: int = 256,
        errors_projection_dim: int = 512,
        has_issue_dropout: float = 0.25,
        hidden_dropout: float = 0.25,
        errors_dropout: float = 0.30,
        temperature_has_issue: float = 1.0,
        temperature_is_hidden: float = 1.0,
        temperature_errors: Optional[List[float]] = None,
        threshold_has_issue: float = 0.5,
        threshold_is_hidden: float = 0.5,
        threshold_error: float = 0.5,
        threshold_errors: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.schema_version = str(schema_version)
        self.base_model_name = base_model_name
        # Saved encoder config dict (if any); used to rebuild the encoder
        # without re-downloading the base model's config.
        self.encoder_config = encoder_config
        self.error_types = list(error_types or [])
        self.num_error_types = len(self.error_types)

        self.has_issue_projection_dim = int(has_issue_projection_dim)
        self.hidden_projection_dim = int(hidden_projection_dim)
        self.errors_projection_dim = int(errors_projection_dim)

        self.has_issue_dropout = float(has_issue_dropout)
        self.hidden_dropout = float(hidden_dropout)
        self.errors_dropout = float(errors_dropout)

        # Temperature-scaling parameters; 1.0 means "uncalibrated".
        self.temperature_has_issue = float(temperature_has_issue)
        self.temperature_is_hidden = float(temperature_is_hidden)
        self.temperature_errors = (
            list(temperature_errors)
            if temperature_errors is not None
            else [1.0] * self.num_error_types
        )

        # Decision thresholds; per-error thresholds default to the scalar one.
        self.threshold_has_issue = float(threshold_has_issue)
        self.threshold_is_hidden = float(threshold_is_hidden)
        self.threshold_error = float(threshold_error)
        self.threshold_errors = (
            list(threshold_errors)
            if threshold_errors is not None
            else [self.threshold_error] * self.num_error_types
        )

        # NOTE(review): plain attribute assignment should not raise, so this
        # try/except looks redundant — presumably defensive against
        # PretrainedConfig property setters in some transformers versions;
        # confirm before removing.
        try:
            self._experts_implementation = "eager"
            self._experts_implementation_internal = "eager"
        except Exception:
            pass
70
+
71
+
72
def build_encoder_config_from_saved_dict(
    encoder_config: Optional[Dict[str, Any]],
    base_model_name: str,
):
    """Rebuild an encoder config from a saved dict, falling back to the hub.

    When *encoder_config* carries a usable ``model_type``, a config of that
    type is reconstructed from the dict. On any failure — or when no dict or
    model_type is available — the config for *base_model_name* is fetched
    instead.
    """
    if encoder_config is not None:
        saved = dict(encoder_config)
        # Drop fields that should not be replayed into the constructor.
        stored_type = saved.pop("model_type", None)
        saved.pop("_name_or_path", None)

        if stored_type is not None:
            try:
                return AutoConfig.for_model(stored_type, **saved)
            except Exception:
                # Incompatible saved dict — fall through to the hub fallback.
                pass

    return AutoConfig.from_pretrained(base_model_name)
90
+
91
+
92
class MeanPooling(nn.Module):
    """Masked mean pooling over the sequence dimension.

    Averages token embeddings across dim 1, counting only positions where
    the attention mask is non-zero; the denominator is clamped so a
    fully-masked row cannot divide by zero.
    """

    def forward(
        self,
        last_hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        weights = attention_mask.unsqueeze(-1).float()
        token_total = (last_hidden_state * weights).sum(dim=1)
        valid_count = weights.sum(dim=1).clamp(min=1e-9)
        return token_total / valid_count
102
+
103
+
104
class RQAModelHF(PreTrainedModel):
    """Encoder + mean pooling + three classification heads.

    The heads predict: has_issue (binary), is_hidden (binary), and a
    multilabel score per error type. forward returns raw logits only;
    temperature scaling and thresholding happen in the inference wrapper.
    """

    config_class = RQAModelConfig
    _supports_grouped_mm = False

    def __init__(self, config: RQAModelConfig):
        # Force eager expert implementation on the config before the base
        # class reads it. NOTE(review): plain assignment should not raise —
        # presumably defensive against config property setters; confirm.
        try:
            config._experts_implementation = "eager"
            config._experts_implementation_internal = "eager"
        except Exception:
            pass
        super().__init__(config)

        # Persist the encoder config into our own config so reloading does
        # not need network access to the base model.
        if config.encoder_config is None:
            base_cfg = AutoConfig.from_pretrained(config.base_model_name)
            config.encoder_config = base_cfg.to_dict()

        enc_cfg = build_encoder_config_from_saved_dict(
            encoder_config=config.encoder_config,
            base_model_name=config.base_model_name,
        )
        # from_config builds an uninitialized encoder; weights arrive via
        # from_pretrained's state-dict loading.
        self.encoder = AutoModel.from_config(enc_cfg)

        hidden_size = self.encoder.config.hidden_size
        self.pooler = MeanPooling()

        # One projection MLP (Linear -> LayerNorm -> GELU -> Dropout) per head.
        self.has_issue_projection = nn.Sequential(
            nn.Linear(hidden_size, config.has_issue_projection_dim),
            nn.LayerNorm(config.has_issue_projection_dim),
            nn.GELU(),
            nn.Dropout(config.has_issue_dropout),
        )
        self.hidden_projection = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_projection_dim),
            nn.LayerNorm(config.hidden_projection_dim),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout),
        )
        self.errors_projection = nn.Sequential(
            nn.Linear(hidden_size, config.errors_projection_dim),
            nn.LayerNorm(config.errors_projection_dim),
            nn.GELU(),
            nn.Dropout(config.errors_dropout),
        )

        self.has_issue_head = nn.Linear(config.has_issue_projection_dim, 1)
        self.is_hidden_head = nn.Linear(config.hidden_projection_dim, 1)
        self.errors_head = nn.Linear(
            config.errors_projection_dim,
            config.num_error_types,
        )

        # Per-task log-variance parameters (uncertainty-weighting style),
        # clamped to a sane range at construction time.
        self.log_var_has_issue = nn.Parameter(torch.zeros(1))
        self.log_var_is_hidden = nn.Parameter(torch.zeros(1))
        self.log_var_errors = nn.Parameter(torch.zeros(1))
        with torch.no_grad():
            self.log_var_has_issue.clamp_(-5, 5)
            self.log_var_is_hidden.clamp_(-5, 5)
            self.log_var_errors.clamp_(-5, 5)

        # Mark only our own Linear layers for Xavier init in _init_weights,
        # so post_init does not touch the encoder's weights.
        for module in [
            self.has_issue_projection[0],
            self.hidden_projection[0],
            self.errors_projection[0],
            self.has_issue_head,
            self.is_hidden_head,
            self.errors_head,
        ]:
            setattr(module, "_rqa_custom_init", True)

        self.post_init()

    def _init_weights(self, module):
        # Only layers explicitly tagged above get re-initialized.
        if isinstance(module, nn.Linear) and getattr(module, "_rqa_custom_init", False):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Encode, mean-pool, and return the three heads' raw logits."""
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs,
        )
        pooled = self.pooler(outputs.last_hidden_state, attention_mask)

        issue_features = self.has_issue_projection(pooled)
        hidden_features = self.hidden_projection(pooled)
        error_features = self.errors_projection(pooled)

        # Binary heads are squeezed to shape (batch,); errors stay (batch, num_error_types).
        return {
            "has_issue_logits": self.has_issue_head(issue_features).squeeze(-1),
            "is_hidden_logits": self.is_hidden_head(hidden_features).squeeze(-1),
            "errors_logits": self.errors_head(error_features),
        }
204
+
205
+
206
# Register the custom model type with the Auto* factories at import time so
# AutoConfig/AutoModel.from_pretrained can resolve "rqa_v2_2" checkpoints.
# register() raises ValueError when the type is already registered (e.g.
# double import) — that is harmless and deliberately ignored.
try:
    AutoConfig.register("rqa_v2_2", RQAModelConfig)
except ValueError:
    pass

try:
    AutoModel.register(RQAModelConfig, RQAModelHF)
except ValueError:
    pass