"""LLM Chains for evaluating question answering."""

from __future__ import annotations

import re
import string
from typing import Any, List, Optional, Sequence, Tuple

from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from pydantic import ConfigDict

from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.evaluation.schema import LLMEvalChain, StringEvaluator
from langchain.schema import RUN_KEY


def _get_score(text: str) -> Optional[Tuple[str, int]]:
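    """Extract a "CORRECT"/"INCORRECT" verdict and a 1/0 score from grader text.

    Returns ``None`` when no verdict can be found. A doctest-style illustration
    (hypothetical grader outputs, not taken from real runs):

    >>> _get_score("GRADE: CORRECT")
    ('CORRECT', 1)
    >>> _get_score("Incorrect. The answer omits the key detail.")
    ('INCORRECT', 0)
    >>> _get_score("The grader gave no verdict.") is None
    True
    """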
    match = re.search(r"grade:\s*(correct|incorrect)", text.strip(), re.IGNORECASE)
    if match:
        if match.group(1).upper() == "CORRECT":
            return "CORRECT", 1
        elif match.group(1).upper() == "INCORRECT":
            return "INCORRECT", 0
    try:
        # Fall back to checking the first and last words, stripped of punctuation.
        first_word = (
            text.strip().split()[0].translate(str.maketrans("", "", string.punctuation))
        )
        if first_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif first_word.upper() == "INCORRECT":
            return "INCORRECT", 0
        last_word = (
            text.strip()
            .split()[-1]
            .translate(str.maketrans("", "", string.punctuation))
        )
        if last_word.upper() == "CORRECT":
            return "CORRECT", 1
        elif last_word.upper() == "INCORRECT":
            return "INCORRECT", 0
    except IndexError:
        pass
    return None


def _parse_string_eval_output(text: str) -> dict:
    """Parse the grader's output text into reasoning, value, and score.

    Args:
        text (str): The output text to parse.

    Returns:
        dict: The parsed output, with "reasoning", "value", and "score" keys.
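
    A doctest-style illustration (hypothetical grader output, shown for shape only):

    >>> _parse_string_eval_output("GRADE: CORRECT")
    {'reasoning': 'GRADE: CORRECT', 'value': 'CORRECT', 'score': 1}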
    """
    reasoning = text.strip()
    parsed_scores = _get_score(reasoning)
    if parsed_scores is None:
        value, score = None, None
    else:
        value, score = parsed_scores
    return {
        "reasoning": reasoning,
        "value": value,
        "score": score,
    }


class QAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating question answering."""

    output_key: str = "results"

    model_config = ConfigDict(
        extra="ignore",
    )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def evaluation_name(self) -> str:
        return "correctness"

    @property
    def requires_reference(self) -> bool:
        return True

    @property
    def requires_input(self) -> bool:
        return True

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> QAEvalChain:
"""Load QA Eval Chain from LLM. |
|
|
|
Args: |
|
llm (BaseLanguageModel): the base language model to use. |
|
|
|
prompt (PromptTemplate): A prompt template containing the input_variables: |
|
'input', 'answer' and 'result' that will be used as the prompt |
|
for evaluation. |
|
Defaults to PROMPT. |
|
|
|
**kwargs: additional keyword arguments. |
|
|
|
Returns: |
|
QAEvalChain: the loaded QA eval chain. |
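
        Example:
            A minimal sketch (assumes an ``llm`` language model instance; the
            question/answer data below is hypothetical):

            .. code-block:: python

                from langchain.evaluation.qa import QAEvalChain

                eval_chain = QAEvalChain.from_llm(llm)
                graded = eval_chain.evaluate(
                    examples=[{"query": "What is 2 + 2?", "answer": "4"}],
                    predictions=[{"result": "4"}],
                )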
        """
        prompt = prompt or PROMPT
        expected_input_vars = {"query", "answer", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )
        return cls(llm=llm, prompt=prompt, **kwargs)

    def evaluate(
        self,
        examples: Sequence[dict],
        predictions: Sequence[dict],
        question_key: str = "query",
        answer_key: str = "answer",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions."""
        inputs = [
            {
                "query": example[question_key],
                "answer": example[answer_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]

        return self.apply(inputs, callbacks=callbacks)

    def _prepare_output(self, result: dict) -> dict:
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        """Evaluate Chain or LLM output, based on optional input and label.

        Args:
            prediction (str): the LLM or chain prediction to evaluate.
            reference (Optional[str], optional): the reference label
                to evaluate against.
            input (Optional[str], optional): the input to consider during evaluation.
            callbacks (Callbacks, optional): the callbacks to use for tracing.
            include_run_info (bool, optional): whether to include run info in the
                returned results.
            **kwargs: additional keyword arguments, including callbacks, tags, etc.

        Returns:
            dict: The evaluation results containing the score or value.
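
        Example:
            A hedged sketch using the public ``evaluate_strings`` wrapper from
            ``StringEvaluator`` (assumes ``eval_chain = QAEvalChain.from_llm(llm)``;
            the data and the shown result are illustrative only):

            .. code-block:: python

                eval_chain.evaluate_strings(
                    input="What is the capital of France?",
                    reference="Paris",
                    prediction="The capital of France is Paris.",
                )
                # e.g. {"reasoning": "GRADE: CORRECT", "value": "CORRECT", "score": 1}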
        """
        result = self(
            {
                "query": input,
                "answer": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = await self.acall(
            inputs={"query": input, "answer": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)


class ContextQAEvalChain(LLMChain, StringEvaluator, LLMEvalChain):
    """LLM Chain for evaluating QA without ground truth, based on context."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def requires_reference(self) -> bool:
        """Whether the chain requires a reference string."""
        return True

    @property
    def requires_input(self) -> bool:
        """Whether the chain requires an input string."""
        return True

    model_config = ConfigDict(
        extra="ignore",
    )

    @classmethod
    def _validate_input_vars(cls, prompt: PromptTemplate) -> None:
        expected_input_vars = {"query", "context", "result"}
        if expected_input_vars != set(prompt.input_variables):
            raise ValueError(
                f"Input variables should be {expected_input_vars}, "
                f"but got {prompt.input_variables}"
            )

    @property
    def evaluation_name(self) -> str:
        return "Contextual Accuracy"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> ContextQAEvalChain:
"""Load QA Eval Chain from LLM. |
|
|
|
Args: |
|
llm (BaseLanguageModel): the base language model to use. |
|
|
|
prompt (PromptTemplate): A prompt template containing the input_variables: |
|
'query', 'context' and 'result' that will be used as the prompt |
|
for evaluation. |
|
Defaults to PROMPT. |
|
|
|
**kwargs: additional keyword arguments. |
|
|
|
Returns: |
|
ContextQAEvalChain: the loaded QA eval chain. |
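
        Example:
            A minimal sketch (assumes an ``llm`` language model instance; the
            context and answer below are hypothetical):

            .. code-block:: python

                from langchain.evaluation.qa import ContextQAEvalChain

                eval_chain = ContextQAEvalChain.from_llm(llm)
                graded = eval_chain.evaluate(
                    examples=[
                        {
                            "query": "Who wrote Hamlet?",
                            "context": "Hamlet is a tragedy by William Shakespeare.",
                        }
                    ],
                    predictions=[{"result": "William Shakespeare"}],
                )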
        """
        prompt = prompt or CONTEXT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)

    def evaluate(
        self,
        examples: List[dict],
        predictions: List[dict],
        question_key: str = "query",
        context_key: str = "context",
        prediction_key: str = "result",
        *,
        callbacks: Callbacks = None,
    ) -> List[dict]:
        """Evaluate question answering examples and predictions."""
        inputs = [
            {
                "query": example[question_key],
                "context": example[context_key],
                "result": predictions[i][prediction_key],
            }
            for i, example in enumerate(examples)
        ]

        return self.apply(inputs, callbacks=callbacks)

    def _prepare_output(self, result: dict) -> dict:
        parsed_result = _parse_string_eval_output(result[self.output_key])
        if RUN_KEY in result:
            parsed_result[RUN_KEY] = result[RUN_KEY]
        return parsed_result

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = self(
            {
                "query": input,
                "context": reference,
                "result": prediction,
            },
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)

    async def _aevaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        callbacks: Callbacks = None,
        include_run_info: bool = False,
        **kwargs: Any,
    ) -> dict:
        result = await self.acall(
            inputs={"query": input, "context": reference, "result": prediction},
            callbacks=callbacks,
            include_run_info=include_run_info,
        )
        return self._prepare_output(result)


class CotQAEvalChain(ContextQAEvalChain):
    """LLM Chain for evaluating QA using chain of thought reasoning."""
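    # Usage mirrors ContextQAEvalChain, but grading uses the chain-of-thought
    # COT_PROMPT. A hedged sketch (assumes an ``llm`` language model instance
    # and example/prediction dicts shaped as in ContextQAEvalChain.evaluate):
    #
    #     eval_chain = CotQAEvalChain.from_llm(llm)
    #     graded = eval_chain.evaluate(examples, predictions)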

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def evaluation_name(self) -> str:
        return "COT Contextual Accuracy"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> CotQAEvalChain:
        """Load QA Eval Chain from LLM."""
        prompt = prompt or COT_PROMPT
        cls._validate_input_vars(prompt)
        return cls(llm=llm, prompt=prompt, **kwargs)