| """Typed Pydantic models for the IndicScriptureQA environment.""" |
|
|
| from __future__ import annotations |
|
|
| from enum import Enum |
| from typing import Dict, List, Optional |
|
|
| from pydantic import BaseModel, Field |
|
|
|
|
| |
|
|
class ActionType(str, Enum):
    """Closed set of action kinds an ``Action`` can carry.

    Each member's value is the string spelling of its own name.  Because the
    class mixes in ``str``, members compare equal to their plain string value
    (``ActionType.EDIT == "EDIT"``) and can be looked up from it
    (``ActionType("EDIT")``).
    """

    RETRIEVE = "RETRIEVE"
    EDIT = "EDIT"
    RESTRUCTURE = "RESTRUCTURE"
    CITE = "CITE"
    ACCEPT = "ACCEPT"
    REJECT = "REJECT"
|
|
|
|
class Action(BaseModel):
    """One action: a required kind plus an optional string payload."""

    # Which of the ``ActionType`` members this action is.
    action_type: ActionType
    # Optional free-form text accompanying the action; defaults to ``None``.
    # NOTE(review): how the payload is interpreted for each action type is not
    # defined in this module — presumably by the environment's step logic;
    # confirm against the caller.
    payload: Optional[str] = None
|
|
|
|
| |
|
|
class StructuralMeta(BaseModel):
    """Describes the *expected* semantic structure of a correct answer.

    Every field is a list of strings that defaults to a fresh empty list, so a
    bare ``StructuralMeta()`` carries no structural expectations at all.
    """

    required_terms: List[str] = Field(
        default_factory=list,
        description="Sanskrit / domain terms the answer must contain.",
    )
    required_sections: List[str] = Field(
        default_factory=list,
        description="Conceptual aspects the answer should cover (order-independent).",
    )
    expected_order: List[str] = Field(
        default_factory=list,
        description="Concepts that should appear in this logical sequence.",
    )
    banned_terms: List[str] = Field(
        default_factory=list,
        description="Terms that indicate a common misconception if present.",
    )
|
|
|
|
| |
|
|
class Observation(BaseModel):
    """Snapshot of the fields that ``EnvState.to_observation()`` exposes.

    ``question``, ``current_answer``, ``steps_remaining`` and ``task_name`` are
    required; the list fields default to fresh empty lists and ``feedback``
    defaults to ``None``.
    """

    question: str
    current_answer: str
    retrieved_passages: List[str] = Field(default_factory=list)
    current_citations: List[str] = Field(default_factory=list)
    steps_remaining: int
    task_name: str
    feedback: Optional[str] = None

    structural_hints: List[str] = Field(
        default_factory=list,
        description="High-level hints about expected answer structure.",
    )
|
|
|
|
| |
|
|
class StepResult(BaseModel):
    """Bundle of an observation with a reward, a done flag and an info dict.

    NOTE(review): presumably the value returned by the environment after each
    step — the producer is not visible in this module; confirm at the caller.
    """

    observation: Observation
    reward: float = 0.0
    done: bool = False
    # Untyped dict with no key/value schema declared here; a fresh empty dict
    # is created per instance.
    info: dict = Field(default_factory=dict)
|
|
|
|
| |
|
|
class EnvState(BaseModel):
    """Full environment state: the observable fields plus internal bookkeeping.

    Only the first group of fields is copied out by ``to_observation()``; the
    remaining groups stay on the state object.  ``question`` and
    ``current_answer`` are required, every other field has a default.
    """

    # --- Fields mirrored into ``Observation`` by ``to_observation()`` ---
    question: str
    current_answer: str
    retrieved_passages: List[str] = Field(default_factory=list)
    current_citations: List[str] = Field(default_factory=list)
    steps_remaining: int = 0
    task_name: str = ""
    feedback: Optional[str] = None
    structural_hints: List[str] = Field(default_factory=list)

    # --- Fields not included in the observation ---
    # NOTE(review): presumably reference data used for grading — confirm
    # against the environment code that reads them.
    original_answer: str = ""
    ground_truth_answer: str = ""
    ground_truth_citations: List[str] = Field(default_factory=list)
    available_passages: List[str] = Field(default_factory=list)
    answer_is_correct: bool = False
    factual_is_correct: bool = False

    # Structural expectations; defaults to an empty ``StructuralMeta()``.
    structural_meta: StructuralMeta = Field(default_factory=StructuralMeta)

    # --- Counters and reward tracking (all start at zero / empty) ---
    step_count: int = 0
    max_steps: int = 8
    done: bool = False
    cumulative_reward: float = 0.0
    rewards: List[float] = Field(default_factory=list)
    retrieval_count: int = 0
    edit_count: int = 0
    restructure_count: int = 0

    def to_observation(self) -> Observation:
        """Build an ``Observation`` from the mirrored fields of this state.

        List-valued fields are passed through ``list(...)`` so the returned
        observation holds its own list objects rather than the state's.

        Returns:
            A new ``Observation`` populated from this state.
        """
        return Observation(
            question=self.question,
            current_answer=self.current_answer,
            retrieved_passages=list(self.retrieved_passages),
            current_citations=list(self.current_citations),
            steps_remaining=self.steps_remaining,
            task_name=self.task_name,
            feedback=self.feedback,
            structural_hints=list(self.structural_hints),
        )
|
|