# RQA-R2 / inference.py
# Hugging Face file-page header (kept for provenance):
#   author: skatzR
#   commit: "Update inference.py" (21aa3b6, verified)
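# Requirements
# This stack is sufficient for running inference in Colab.
# The line below is a Colab/Jupyter cell command, not Python; run it in a
# notebook cell as-is, or from a shell without the leading "!":
# !pip install transformers==4.48.3 tokenizers sentencepiece accelerate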
# ============================================================
# RQA UX Inference — R2 Interactive Version
# Google Colab + CLI friendly
# ============================================================
import os
import json
import csv
import torch
from typing import List, Optional
from transformers import AutoTokenizer, AutoModel
# ============================================================
# Константы
# ============================================================
# Default explicit-error labels. RQAJudge uses this list when the model config
# does not provide its own `error_types`; position i is paired with the i-th
# entry of the model's error logits, so the order is significant.
ERROR_TYPES = [
    "false_causality",
    "unsupported_claim",
    "overgeneralization",
    "missing_premise",
    "contradiction",
    "circular_reasoning",
]

# Russian display names for the labels above, used when printing results.
ERROR_NAMES_RU = {
    "false_causality": "Ложная причинно-следственная связь",
    "unsupported_claim": "Неподкрепленное утверждение",
    "overgeneralization": "Чрезмерное обобщение",
    "missing_premise": "Отсутствующая предпосылка",
    "contradiction": "Противоречие",
    "circular_reasoning": "Круговое рассуждение",
}
# ============================================================
# RQA Judge
# ============================================================
class RQAJudge:
    """Inference wrapper around an RQA model checkpoint.

    Loads the tokenizer and model once, reads calibration temperatures and
    decision thresholds from the model config (with defaults when a field is
    absent), and classifies a text in up to three stages:

    1. ``has_issue`` gate  -> class ``"logical"`` when no issue is detected;
    2. ``is_hidden`` gate  -> class ``"hidden"`` when the issue is hidden;
    3. per-type explicit errors -> class ``"explicit"``.
    """

    def __init__(self, model_name="skatzR/RQA-R2", device=None, max_length: int = 512):
        """Load tokenizer/model and cache calibration settings.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Target device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            max_length: Tokenizer truncation/padding length.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = int(max_length)

        # NOTE(security): trust_remote_code=True executes Python code shipped
        # with the model repository. Only use it with repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
        ).to(self.device)
        self.model.eval()

        cfg = self.model.config
        self.error_types = list(getattr(cfg, "error_types", ERROR_TYPES))

        # Calibration temperatures: each logit is divided by its temperature
        # before the sigmoid.
        self.temp_issue = float(getattr(cfg, "temperature_has_issue", 1.0))
        self.temp_hidden = float(getattr(cfg, "temperature_is_hidden", 1.0))
        self.temp_errors = list(
            getattr(cfg, "temperature_errors", [1.0] * len(self.error_types))
        )

        # Decision thresholds; the per-type list defaults to the scalar one.
        self.threshold_issue = float(getattr(cfg, "threshold_has_issue", 0.5))
        self.threshold_hidden = float(getattr(cfg, "threshold_is_hidden", 0.5))
        self.threshold_error = float(getattr(cfg, "threshold_error", 0.5))
        self.threshold_errors = list(
            getattr(cfg, "threshold_errors", [self.threshold_error] * len(self.error_types))
        )

    # ----------------------
    # Core inference
    # ----------------------
    @torch.no_grad()
    def infer(
        self,
        text: str,
        issue_threshold: Optional[float] = None,
        hidden_threshold: Optional[float] = None,
        error_threshold: Optional[float] = None,
        error_thresholds: Optional[List[float]] = None,
        issue_uncertain_margin: float = 0.05,
        hidden_uncertain_margin: float = 0.05,
        error_uncertain_margin: float = 0.05,
    ):
        """Classify a single text and return a result dictionary.

        Args:
            text: Text to analyse.
            issue_threshold: Override for the ``has_issue`` threshold.
            hidden_threshold: Override for the ``is_hidden`` threshold.
            error_threshold: Scalar override applied to every error type.
                It is used when ``error_thresholds`` is not given, and as the
                fallback for positions missing from ``error_thresholds``.
            error_thresholds: Per-type overrides, aligned with
                ``self.error_types``; takes precedence over ``error_threshold``.
            issue_uncertain_margin: A probability within this distance of its
                threshold marks the result as ``"uncertain"``.
            hidden_uncertain_margin: Same, for the hidden gate.
            error_uncertain_margin: Same, for each explicit error type.

        Returns:
            dict with keys ``text``, ``class`` (``"logical"``, ``"hidden"`` or
            ``"explicit"``), ``status`` (``"ok"``/``"uncertain"``),
            ``review_required``, ``has_issue``, ``issue_probability``,
            ``hidden_problem``, ``hidden_probability`` (``None`` if the hidden
            gate was not evaluated), ``errors`` (list of ``(name, probability)``
            sorted by descending probability), ``num_errors``,
            ``schema_version`` and the thresholds that were applied.

        Raises:
            ValueError: If the number of error logits returned by the model
                differs from ``len(self.error_types)``.
        """
        issue_threshold = self.threshold_issue if issue_threshold is None else float(issue_threshold)
        hidden_threshold = self.threshold_hidden if hidden_threshold is None else float(hidden_threshold)

        scalar_override = error_threshold is not None
        error_threshold = self.threshold_error if error_threshold is None else float(error_threshold)
        if error_thresholds is not None:
            error_thresholds = list(error_thresholds)
        elif scalar_override:
            # An explicit scalar must win over the per-type defaults from the
            # config; otherwise the caller's override would never be applied.
            error_thresholds = [error_threshold] * len(self.error_types)
        else:
            # Copy so the returned result never aliases the judge's own state.
            error_thresholds = list(self.threshold_errors)

        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model(**inputs)

        # ----- has_issue -----
        issue_logit = outputs["has_issue_logits"] / self.temp_issue
        issue_prob = torch.sigmoid(issue_logit).item()
        has_issue = issue_prob >= issue_threshold

        result = {
            "text": text,
            "class": None,  # logical / hidden / explicit
            "status": "ok",  # ok / uncertain
            "review_required": False,
            "has_issue": has_issue,
            "issue_probability": issue_prob,
            "hidden_problem": False,
            "hidden_probability": None,
            "errors": [],
            "num_errors": 0,
            "schema_version": getattr(self.model.config, "schema_version", "unknown"),
            "threshold_issue": issue_threshold,
            "threshold_hidden": hidden_threshold,
            "threshold_error": error_threshold,
            "threshold_errors": error_thresholds,
        }

        if abs(issue_prob - issue_threshold) <= issue_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # ----- Gate 1: logical -----
        if not has_issue:
            result["class"] = "logical"
            return result

        # ----- hidden -----
        hidden_logit = outputs["is_hidden_logits"] / self.temp_hidden
        hidden_prob = torch.sigmoid(hidden_logit).item()
        is_hidden = hidden_prob >= hidden_threshold
        result["hidden_problem"] = is_hidden
        result["hidden_probability"] = hidden_prob

        if abs(hidden_prob - hidden_threshold) <= hidden_uncertain_margin:
            result["status"] = "uncertain"
            result["review_required"] = True

        # ----- Gate 2: hidden -----
        if is_hidden:
            result["class"] = "hidden"
            return result

        # ----- explicit errors -----
        raw_error_logits = outputs["errors_logits"][0]
        if len(raw_error_logits) != len(self.error_types):
            raise ValueError(
                f"Model returned {len(raw_error_logits)} error logits, "
                f"but {len(self.error_types)} error types are configured"
            )

        explicit_errors = []
        for i, err_name in enumerate(self.error_types):
            # Missing per-type settings fall back to neutral/scalar defaults.
            temperature = self.temp_errors[i] if i < len(self.temp_errors) else 1.0
            threshold_i = float(error_thresholds[i] if i < len(error_thresholds) else error_threshold)

            prob = torch.sigmoid(raw_error_logits[i] / temperature).item()

            if abs(prob - threshold_i) <= error_uncertain_margin:
                result["status"] = "uncertain"
                result["review_required"] = True
            if prob >= threshold_i:
                explicit_errors.append((err_name, prob))

        explicit_errors.sort(key=lambda pair: pair[1], reverse=True)
        result["class"] = "explicit"
        result["errors"] = explicit_errors
        result["num_errors"] = len(explicit_errors)
        return result

    # ============================================================
    # UX output
    # ============================================================
    def pretty_print(self, r):
        """Print a human-readable report for a result produced by ``infer``.

        Args:
            r: Result dictionary as returned by ``infer``.
        """
        print("\n" + "=" * 72)
        print("📄 Текст:")
        print(r["text"])
        print(
            f"\n🔎 Обнаружена проблема: {'ДА' if r['has_issue'] else 'НЕТ'} "
            f"({r['issue_probability'] * 100:.2f}%)"
        )
        print(f"🧠 Класс: {r['class']}")

        if r["status"] == "uncertain":
            print("⚠️ Пограничный случай: review recommended")

        # The hidden gate is only evaluated when an issue was detected.
        if r["hidden_probability"] is not None:
            print(
                f"🟡 Hidden-проблема: {'ДА' if r['hidden_problem'] else 'НЕТ'} "
                f"({r['hidden_probability'] * 100:.2f}%)"
            )

        if r["errors"]:
            print("\n❌ Явные логические ошибки:")
            for name, prob in r["errors"]:
                # Separator added: name and percentage used to be glued together.
                print(f" • {ERROR_NAMES_RU.get(name, name)} — {prob * 100:.2f}%")
        else:
            print("\n✅ Явных логических ошибок не обнаружено")
        print("=" * 72)
# ============================================================
# Loaders
# ============================================================
def load_texts_from_file(path: str) -> List[str]:
    """Load texts to analyse from a ``.txt``, ``.csv`` or ``.json`` file.

    Supported layouts:
        * ``.txt``  - one text per line; blank lines are skipped and
          surrounding whitespace is stripped.
        * ``.csv``  - a header row with a ``text`` column; rows with an empty
          ``text`` value are skipped.
        * ``.json`` - either a list of strings (returned as is) or a list of
          objects, from which every non-null ``"text"`` value is taken and
          converted to ``str``.

    Files are decoded as UTF-8; a leading byte-order mark is accepted.

    Args:
        path: Path to the input file; the format is chosen by its extension.

    Returns:
        The list of texts found in the file (possibly empty).

    Raises:
        ValueError: If the extension is not supported, or a ``.json`` file
            does not contain a list at the top level.
    """
    ext = os.path.splitext(path)[1].lower()

    # "utf-8-sig" reads plain UTF-8 too, but also drops a BOM if present.
    # Without it a BOM corrupts the first CSV header / TXT line and makes
    # json.load fail.
    if ext == ".txt":
        with open(path, encoding="utf-8-sig") as f:
            stripped = (line.strip() for line in f)
            return [line for line in stripped if line]

    if ext == ".csv":
        # newline="" lets the csv module handle line endings itself, which is
        # required for quoted fields that contain newlines.
        with open(path, encoding="utf-8-sig", newline="") as f:
            return [row["text"] for row in csv.DictReader(f) if row.get("text")]

    if ext == ".json":
        with open(path, encoding="utf-8-sig") as f:
            data = json.load(f)
        if isinstance(data, list):
            if all(isinstance(item, str) for item in data):
                return data
            # Skip null values so they do not become the literal text "None".
            return [
                str(item["text"])
                for item in data
                if isinstance(item, dict) and item.get("text") is not None
            ]

    raise ValueError("Неподдерживаемый формат файла")
# ============================================================
# Interactive CLI Interface
# ============================================================
class InteractiveCLI:
    """Menu-driven console front end for ``RQAJudge``.

    Each ``process_*`` method handles one interaction in its mode and returns
    ``True`` when the user asked to go back to the mode menu (empty input),
    ``False`` when the same mode should be offered again.
    """

    def __init__(self, model_name="skatzR/RQA-R2"):
        """Create the judge used by every mode."""
        self.judge = RQAJudge(model_name=model_name)

    def _show(self, *lines):
        """Print each given line with a separate ``print`` call."""
        for line in lines:
            print(line)

    def _analyze_all(self, texts):
        """Run inference on every text and print a numbered report for each."""
        for number, text in enumerate(texts, 1):
            print(f"\n🔍 Текст #{number}:")
            self.judge.pretty_print(self.judge.infer(text))

    def clear_screen(self):
        """Visually separate screens by printing blank lines."""
        print("\n" * 2)

    def show_mode_menu(self):
        """Print the mode-selection menu."""
        self.clear_screen()
        self._show(
            "=" * 60,
            "🤖 RQA-R2 — АНАЛИЗ ЛОГИЧЕСКИХ ОШИБОК",
            "=" * 60,
            "\nВыберите режим работы:",
            "1. 📝 Одиночный ввод (одна фраза для анализа)",
            "2. 📄 Множественный ввод (несколько фраз, каждая с новой строки)",
            "3. 📂 Загрузка из файла (.txt, .csv, .json)",
            "\nНажмите Enter без ввода для выхода.",
            "-" * 60,
        )

    def process_single_mode(self):
        """Analyse one text typed by the user; empty input returns to the menu."""
        self.clear_screen()
        self._show(
            "[📝 РЕЖИМ: ОДИНОЧНЫЙ ВВОД]",
            "Введите текст для анализа:",
            "(Нажмите Enter без ввода для возврата в меню)",
            "-" * 40,
        )
        text = input("> ").strip()
        if text:
            self.judge.pretty_print(self.judge.infer(text))
            print("\n" + "-" * 40)
            input("Нажмите Enter для продолжения...")
            return False
        return True

    def process_multiline_mode(self):
        """Collect texts line by line until an empty line, then analyse them all.

        An empty first line returns to the menu without analysing anything.
        """
        self.clear_screen()
        self._show(
            "[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД]",
            "Введите тексты для анализа (каждый с новой строки).",
            "Оставьте строку пустой для завершения ввода.",
            "(Нажмите Enter без ввода для возврата в меню)",
            "-" * 40,
        )
        print("Ввод текстов:")
        # Keep reading stripped lines until the first empty one (the sentinel).
        texts = list(iter(lambda: input("> ").strip(), ""))
        if not texts:
            return True

        self.clear_screen()
        print(f"[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД] — найдено {len(texts)} текстов")
        print("-" * 40)
        self._analyze_all(texts)
        print("\n" + "=" * 60)
        input("Нажмите Enter для продолжения...")
        return False

    def process_file_mode(self):
        """Ask for a file path, load its texts and analyse each of them.

        Any exception raised while loading or analysing is reported to the
        user instead of propagating; an empty path returns to the menu.
        """
        self.clear_screen()
        self._show(
            "[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА]",
            "Поддерживаемые форматы: .txt, .csv, .json",
            "Укажите путь к файлу:",
            "(Нажмите Enter без ввода для возврата в меню)",
            "-" * 40,
        )
        file_path = input("Путь к файлу> ").strip()
        if not file_path:
            return True

        try:
            if not os.path.exists(file_path):
                print(f"\n❌ Ошибка: Файл '{file_path}' не найден!")
                input("\nНажмите Enter для продолжения...")
                return False

            texts = load_texts_from_file(file_path)
            if not texts:
                print(f"\n⚠️  Файл '{file_path}' пуст или не содержит текстов!")
                input("\nНажмите Enter для продолжения...")
                return False

            self.clear_screen()
            print(f"[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА] — загружено {len(texts)} текстов")
            print(f"Файл: {file_path}")
            print("-" * 40)
            self._analyze_all(texts)
            print("\n" + "=" * 60)
            input("Нажмите Enter для продолжения...")
        except Exception as e:
            print(f"\n❌ Ошибка при обработке файла: {e!s}")
            input("\nНажмите Enter для продолжения...")
        return False

    def run_interactive(self):
        """Main loop: show the menu, then keep serving the chosen mode.

        The selected mode stays active until its handler returns ``True``;
        an empty menu choice ends the program.
        """
        modes_by_choice = {"1": "single", "2": "multiline", "3": "file"}
        handlers = {
            "single": self.process_single_mode,
            "multiline": self.process_multiline_mode,
            "file": self.process_file_mode,
        }

        current_mode = None
        while True:
            if not current_mode:
                self.show_mode_menu()
                choice = input("Ваш выбор (1-3)> ").strip()
                if not choice:
                    print("\n👋 Выход из программы...")
                    break
                current_mode = modes_by_choice.get(choice)
                if current_mode is None:
                    print("\n❌ Неверный выбор! Попробуйте снова.")
                    input("Нажмите Enter для продолжения...")
                    continue

            if handlers[current_mode]():
                current_mode = None
# ============================================================
# Точка входа
# ============================================================
def main():
    """Entry point: build the interactive CLI and run its main loop."""
    InteractiveCLI().run_interactive()


# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    main()