import re
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from random import sample, shuffle

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel

from .config import settings
from .graph_client import neo4j_client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT
)
from .schemas import KeyIssue

logger = logging.getLogger(__name__)


def extract_cypher(text: str) -> str:
    """Extracts the first Cypher code block or returns the text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()
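
# Example (illustrative): extract_cypher('```cypher\nMATCH (n) RETURN n\n```')
# returns 'MATCH (n) RETURN n'; text with no fenced block is returned stripped.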


def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)


def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")

    # Placeholder schema description; substitute a real schema dump here if one is available.
    schema_info = "Schema not available."

    cypher_llm = get_llm(settings.main_llm_model)
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)


def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""

        concept_llm = get_llm(settings.main_llm_model)
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts)
        }).strip()

        logger.info(f"Concept selected by LLM: {selected_concept}")

        if selected_concept not in concepts:
            logger.warning(
                f"LLM selected concept '{selected_concept}' is not in the known list. "
                "Attempting a case-insensitive substring fallback."
            )
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""
            selected_concept = found_match

        # Map the plan step to the node label and fields to retrieve.
        if plan_step <= 1:
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:
            target = "(n)"
            fields = "n.title, n.description"

        # The concept name is inlined (with single quotes escaped) because
        # retrieve_documents() accepts a raw query string with no parameters.
        escaped_concept = selected_concept.replace("'", "\\'")
        cypher = f"MATCH (c:Concept {{name: '{escaped_concept}'}})-[:RELATED_TO]-{target} RETURN {fields}"

        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        time.sleep(60)  # Back off before the caller retries, e.g. after an LLM rate-limit error.
        return ""


def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []

    # Cap the result set; only append LIMIT if the query does not already set one.
    if "limit" not in cypher_query.lower():
        cypher_query += " LIMIT 10"
    logger.info(f"Retrieving documents with Cypher: {cypher_query}")
    try:
        raw_results = neo4j_client.query(cypher_query)

        # Deduplicate records (assumes the returned property values are hashable).
        processed_results = []
        seen = set()
        for doc in raw_results:
            doc_items = frozenset(doc.items())
            if doc_items not in seen:
                processed_results.append(doc)
                seen.add(doc_items)
        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        logger.error(f"Document retrieval failed: {e}")
        return []


class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")


class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")
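
# The score grader is expected to emit JSON matching GradeDocumentsScore,
# e.g. (illustrative): {"rationale": "Directly discusses the queried spec.", "score": 0.8}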


def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using the configured method."""
    if not docs:
        return []

    logger.info(
        f"Evaluating {len(docs)} documents for relevance to query: '{query}' "
        f"using method: {settings.eval_method}"
    )
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and "yes" in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

    elif settings.eval_method == "score":
        # JsonOutputParser returns a plain dict (the pydantic_object is only used for
        # schema guidance), so the result is accessed with dict lookups.
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                score = float(result.get("score", 0.0))
                rationale = result.get("rationale", "")
                logger.debug(
                    f"Score grader result for doc '{doc.get('title', 'N/A')}': "
                    f"Score={score}, Rationale={rationale}"
                )
                if score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)

    # Keep the highest-scoring documents first when real scores are available.
    if settings.eval_method == "score":
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)

    final_docs = [doc for doc, score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")

    return final_docs
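

# Illustrative end-to-end usage (a minimal sketch; assumes Neo4j and the LLM settings
# are configured and this module is imported from its package):
#
#   cypher = generate_cypher_guided("How is concept X covered in the specs?", plan_step=1)
#   docs = retrieve_documents(cypher)
#   relevant = evaluate_documents(docs, "How is concept X covered in the specs?")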