from __future__ import annotations

import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever

from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
    PROMPT,
    QUESTION_GENERATOR_PROMPT,
    FinishedOutputParser,
)
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI


class _ResponseChain(LLMChain):
    """Base class for chains that generate responses."""

    prompt: BasePromptTemplate = PROMPT

    @property
    def input_keys(self) -> List[str]:
        return self.prompt.input_variables

    def generate_tokens_and_log_probs(
        self,
        _input: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[Sequence[str], Sequence[float]]:
        llm_result = self.generate([_input], run_manager=run_manager)
        return self._extract_tokens_and_log_probs(llm_result.generations[0])

    @abstractmethod
    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Extract tokens and log probs from response."""


class _OpenAIResponseChain(_ResponseChain):
    """Chain that generates responses from user input and context."""

    llm: OpenAI = Field(
        default_factory=lambda: OpenAI(
            max_tokens=32, model_kwargs={"logprobs": 1}, temperature=0
        )
    )

    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        tokens = []
        log_probs = []
        for gen in generations:
            if gen.generation_info is None:
                raise ValueError
            tokens.extend(gen.generation_info["logprobs"]["tokens"])
            log_probs.extend(gen.generation_info["logprobs"]["token_logprobs"])
        return tokens, log_probs
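
# Note: _OpenAIResponseChain assumes the OpenAI completions-style response, where
# each Generation carries generation_info["logprobs"] with parallel "tokens" and
# "token_logprobs" lists (this is what model_kwargs={"logprobs": 1} requests).
# Generations that arrive without generation_info raise ValueError above.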


class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input", "context", "response"]


def _low_confidence_spans(
    tokens: Sequence[str],
    log_probs: Sequence[float],
    min_prob: float,
    min_token_gap: int,
    num_pad_tokens: int,
) -> List[str]:
    _low_idx = np.where(np.exp(log_probs) < min_prob)[0]
    low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
    if len(low_idx) == 0:
        return []
    spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
    for i, idx in enumerate(low_idx[1:]):
        end = idx + num_pad_tokens + 1
        if idx - low_idx[i] < min_token_gap:
            spans[-1][1] = end
        else:
            spans.append([idx, end])
    return ["".join(tokens[start:end]) for start, end in spans]


class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator."""

    question_generator_chain: QuestionGeneratorChain
    """Chain that generates questions from uncertain spans."""
    response_chain: _ResponseChain = Field(default_factory=_OpenAIResponseChain)
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens to pad around a low confidence span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            docs.extend(self.retriever.get_relevant_documents(question))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.predict(
            user_input=user_input,
            context=context,
            response=response,
            callbacks=callbacks,
        )
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        question_gen_outputs = self.question_generator_chain.apply(
            question_gen_inputs, callbacks=callbacks
        )
        questions = [
            output[self.question_generator_chain.output_keys[0]]
            for output in question_gen_outputs
        ]
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        user_input = inputs[self.input_keys[0]]
        response = ""
        for i in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
                _input, run_manager=_run_manager
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue
            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Creates a FlareChain from a language model.

        Args:
            llm: Language model to use.
            max_generation_len: Maximum length of the generated response.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            FlareChain class with the given language model.
        """
        question_gen_chain = QuestionGeneratorChain(llm=llm)
        response_llm = OpenAI(
            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
        )
        response_chain = _OpenAIResponseChain(llm=response_llm)
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )
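

# Example usage (a minimal sketch; assumes an OpenAI API key is configured and
# that `retriever` is an existing BaseRetriever, e.g. built from a vector store):
#
#     from langchain.chains.flare.base import FlareChain
#     from langchain.llms.openai import OpenAI
#
#     flare = FlareChain.from_llm(
#         OpenAI(temperature=0),
#         retriever=retriever,
#         max_generation_len=64,   # illustrative value
#     )
#     answer = flare.run("user question goes here")  # returns the "response" text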