RQA-R1 / inference.py
skatzR's picture
Update inference.py
7f63d50 verified
# ============================================================
# RQA UX Inference — IMPROVED INTERACTIVE VERSION
# Google Colab + CLI friendly
# ============================================================
import os
import sys
import json
import csv
import torch
from typing import List
from transformers import AutoTokenizer, AutoModel
# ============================================================
# Константы
# ============================================================
ERROR_TYPES = [
"false_causality",
"unsupported_claim",
"overgeneralization",
"missing_premise",
"contradiction",
"circular_reasoning",
]
ERROR_NAMES_RU = {
"false_causality": "Ложная причинно-следственная связь",
"unsupported_claim": "Неподкреплённое утверждение",
"overgeneralization": "Чрезмерное обобщение",
"missing_premise": "Отсутствующая предпосылка",
"contradiction": "Противоречие",
"circular_reasoning": "Круговое рассуждение",
}
ERROR_THRESHOLDS = {
"false_causality": 0.55,
"unsupported_claim": 0.55,
"overgeneralization": 0.60,
"missing_premise": 0.80, # диагностический
"contradiction": 0.60,
"circular_reasoning": 0.60,
}
# ============================================================
# RQA Judge
# ============================================================
class RQAJudge:
def __init__(self, model_name="skatzR/RQA-R1", device=None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True
)
self.model = AutoModel.from_pretrained(
model_name,
trust_remote_code=True
).to(self.device)
self.model.eval()
cfg = self.model.config
self.temp_issue = float(cfg.temperature_has_issue)
self.temp_errors = list(cfg.temperature_errors)
# ----------------------
# Core inference
# ----------------------
@torch.no_grad()
def infer(
self,
text: str,
issue_threshold: float = 0.6,
disagreement_threshold: float = 0.4,
):
inputs = self.tokenizer(
text,
truncation=True,
max_length=512,
padding="max_length",
return_tensors="pt"
).to(self.device)
outputs = self.model(**inputs)
# ----- has_issue -----
issue_logit = outputs["has_issue_logits"] / self.temp_issue
issue_prob = torch.sigmoid(issue_logit).item()
has_issue = issue_prob >= issue_threshold
# ----- errors -----
raw_error_logits = outputs["errors_logits"][0]
error_probs = {}
for i, logit in enumerate(raw_error_logits):
calibrated = logit / self.temp_errors[i]
prob = torch.sigmoid(calibrated).item()
error_probs[ERROR_TYPES[i]] = prob
# ----- disagreement -----
p_any_error = 1.0
for p in error_probs.values():
p_any_error *= (1.0 - p)
p_any_error = 1.0 - p_any_error
disagreement = abs(issue_prob - p_any_error)
# ----- decision logic -----
explicit_errors = []
hidden_problem = False
for err, prob in error_probs.items():
if prob >= ERROR_THRESHOLDS[err]:
if err == "missing_premise":
hidden_problem = True
else:
explicit_errors.append((err, prob))
explicit_errors.sort(key=lambda x: x[1], reverse=True)
# бинарная голова доминирует
if not has_issue:
explicit_errors = []
borderline = (
not has_issue and hidden_problem and disagreement >= disagreement_threshold
)
return {
"text": text,
"has_issue": has_issue,
"issue_probability": issue_prob,
"errors": explicit_errors,
"hidden_problem": hidden_problem,
"borderline": borderline,
"disagreement": disagreement,
}
# ============================================================
# UX output
# ============================================================
def pretty_print(self, r):
print("\n" + "=" * 72)
print("📄 Текст:")
print(r["text"])
print(f"\n🔎 Обнаружена проблема: {'ДА' if r['has_issue'] else 'НЕТ'} "
f"({r['issue_probability']*100:.2f}%)")
if r["borderline"]:
print("⚠️ Пограничный случай: аргументативный текст")
if r["hidden_problem"]:
print("🟡 Скрытая проблема: возможны неявные предпосылки")
if r["errors"]:
print("\n❌ Явные логические ошибки:")
for name, prob in r["errors"]:
print(f" • {ERROR_NAMES_RU[name]}{prob*100:.2f}%")
else:
print("\n✅ Явных логических ошибок не обнаружено")
print(f"\n📊 Disagreement: {r['disagreement']:.3f}")
print("=" * 72)
# ============================================================
# Loaders
# ============================================================
def load_texts_from_file(path: str) -> List[str]:
ext = os.path.splitext(path)[1].lower()
if ext == ".txt":
with open(path, encoding="utf-8") as f:
return [l.strip() for l in f if l.strip()]
if ext == ".csv":
with open(path, encoding="utf-8") as f:
reader = csv.DictReader(f)
return [row["text"] for row in reader]
if ext == ".json":
with open(path, encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return data
raise ValueError("Неподдерживаемый формат файла")
# ============================================================
# Interactive CLI Interface
# ============================================================
class InteractiveCLI:
def __init__(self):
self.judge = RQAJudge()
self.mode_stack = []
def clear_screen(self):
"""Очистка экрана для Google Colab"""
print("\n" * 2)
def show_mode_menu(self):
"""Показать меню выбора режима"""
self.clear_screen()
print("=" * 60)
print("🤖 RQA — АНАЛИЗ ЛОГИЧЕСКИХ ОШИБОК")
print("=" * 60)
print("\nВыберите режим работы:")
print("1. 📝 Одиночный ввод (одна фраза для анализа)")
print("2. 📄 Множественный ввод (несколько фраз, каждая с новой строки)")
print("3. 📂 Загрузка из файла (.txt, .csv, .json)")
print("\nНажмите Enter без ввода для выхода.")
print("-" * 60)
def process_single_mode(self):
"""Обработка одиночного режима"""
self.clear_screen()
print("[📝 РЕЖИМ: ОДИНОЧНЫЙ ВВОД]")
print("Введите текст для анализа:")
print("(Нажмите Enter без ввода для возврата в меню)")
print("-" * 40)
text = input("> ").strip()
if not text:
return True # Возврат в меню
result = self.judge.infer(text)
self.judge.pretty_print(result)
print("\n" + "-" * 40)
input("Нажмите Enter для продолжения...")
return False # Остаемся в том же режиме
def process_multiline_mode(self):
"""Обработка режима множественного ввода"""
self.clear_screen()
print("[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД]")
print("Введите тексты для анализа (каждый с новой строки).")
print("Оставьте строку пустой для завершения ввода.")
print("(Нажмите Enter без ввода для возврата в меню)")
print("-" * 40)
texts = []
print("Ввод текстов:")
while True:
line = input("> ").strip()
if not line:
if not texts: # Пустой ввод сразу - возврат в меню
return True
break # Завершение ввода
texts.append(line)
if texts:
self.clear_screen()
print(f"[📄 РЕЖИМ: МНОЖЕСТВЕННЫЙ ВВОД] — найдено {len(texts)} текстов")
print("-" * 40)
for i, text in enumerate(texts, 1):
print(f"\n🔍 Текст #{i}:")
result = self.judge.infer(text)
self.judge.pretty_print(result)
print("\n" + "=" * 60)
input("Нажмите Enter для продолжения...")
return False # Остаемся в том же режиме
def process_file_mode(self):
"""Обработка режима загрузки из файла"""
self.clear_screen()
print("[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА]")
print("Поддерживаемые форматы: .txt, .csv, .json")
print("Укажите путь к файлу:")
print("(Нажмите Enter без ввода для возврата в меню)")
print("-" * 40)
file_path = input("Путь к файлу> ").strip()
if not file_path:
return True # Возврат в меню
try:
# Проверка существования файла
if not os.path.exists(file_path):
print(f"\n❌ Ошибка: Файл '{file_path}' не найден!")
input("\nНажмите Enter для продолжения...")
return False # Остаемся в том же режиме
# Загрузка текстов
texts = load_texts_from_file(file_path)
if not texts:
print(f"\n⚠️ Файл '{file_path}' пуст или не содержит текстов!")
input("\nНажмите Enter для продолжения...")
return False # Остаемся в том же режиме
# Обработка текстов
self.clear_screen()
print(f"[📂 РЕЖИМ: ЗАГРУЗКА ИЗ ФАЙЛА] — загружено {len(texts)} текстов")
print(f"Файл: {file_path}")
print("-" * 40)
for i, text in enumerate(texts, 1):
print(f"\n🔍 Текст #{i}:")
result = self.judge.infer(text)
self.judge.pretty_print(result)
print("\n" + "=" * 60)
input("Нажмите Enter для продолжения...")
except Exception as e:
print(f"\n❌ Ошибка при обработке файла: {str(e)}")
input("\nНажмите Enter для продолжения...")
return False # Остаемся в том же режиме
def run_interactive(self):
"""Основной цикл интерактивного интерфейса"""
current_mode = None
while True:
# Если нет текущего режима, показываем главное меню
if not current_mode:
self.show_mode_menu()
choice = input("Ваш выбор (1-3)> ").strip()
if not choice: # Пустой ввод - выход
print("\n👋 Выход из программы...")
break
if choice == "1":
current_mode = "single"
elif choice == "2":
current_mode = "multiline"
elif choice == "3":
current_mode = "file"
else:
print("\n❌ Неверный выбор! Попробуйте снова.")
input("Нажмите Enter для продолжения...")
continue
# Обработка текущего режима
should_return_to_menu = False
if current_mode == "single":
should_return_to_menu = self.process_single_mode()
elif current_mode == "multiline":
should_return_to_menu = self.process_multiline_mode()
elif current_mode == "file":
should_return_to_menu = self.process_file_mode()
# Возврат в меню при необходимости
if should_return_to_menu:
current_mode = None
# ============================================================
# Точка входа
# ============================================================
def main():
"""Основная функция - запускает интерактивный интерфейс"""
cli = InteractiveCLI()
cli.run_interactive()
# ============================================================
# Запуск
# ============================================================
if __name__ == "__main__":
main()