# Spaces: Runtime error  (page banner left over from extraction; not part of the module)
"""Select and order examples based on ngram overlap score (sentence_bleu score).

https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf
"""
from typing import Dict, List | |
import numpy as np | |
from pydantic import BaseModel, root_validator | |
from langchain.prompts.example_selector.base import BaseExampleSelector | |
from langchain.prompts.prompt import PromptTemplate | |
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Compute ngram overlap score of source and example as sentence_bleu score.

    Uses sentence_bleu with the method1 smoothing function and auto
    reweighting. Only the first string of ``source`` is scored: it is split on
    whitespace to form the hypothesis, and every string in ``example`` is
    split on whitespace to form one reference.

    Args:
        source: Input strings; only ``source[0]`` is used.
        example: Reference strings the input is compared against.

    Returns:
        Float value between 0.0 and 1.0 inclusive. 0.0 is returned when
        ``source`` or ``example`` is empty, since there is nothing to compare.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Without a hypothesis or without any reference there can be no overlap.
    # Report zero instead of failing on ``source[0]`` or handing sentence_bleu
    # an empty reference list.
    if not source or not example:
        return 0.0

    # Imported lazily so that nltk is only required when a score is computed.
    from nltk.translate.bleu_score import (  # type: ignore
        SmoothingFunction,
        sentence_bleu,
    )

    hypothesis = source[0].split()
    references = [reference.split() for reference in example]

    return float(
        sentence_bleu(
            references,
            hypothesis,
            smoothing_function=SmoothingFunction().method1,
            auto_reweigh=True,
        )
    )
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    # Registered as a pre root validator so the dependency check actually runs
    # when the model is constructed; without the decorator it is never called.
    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that valid dependencies exist.

        Args:
            values: The raw field values being validated.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If the nltk BLEU helpers cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            raise ValueError(
                "Not all the correct dependencies for this ExampleSelect exist"
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order; examples with equal scores keep their original
        relative order.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold (scores within 1e-9 of the threshold count as equal).
        Returns an empty list when there are no examples.

        Each example is scored on the value stored under the first input
        variable of ``example_prompt``.
        """
        inputs = list(input_variables.values())
        first_prompt_template_key = self.example_prompt.input_variables[0]

        scores = [
            ngram_overlap_score(inputs, [example[first_prompt_template_key]])
            for example in self.examples
        ]

        # A single stable sort replaces the repeated arg-max scan: it yields
        # the same ordering (ties stay in original order) and handles an empty
        # example list without raising.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

        return [
            self.examples[index]
            for index in ranked
            if scores[index] > self.threshold
            and abs(scores[index] - self.threshold) >= 1e-9
        ]