Spaces:
Runtime error
Runtime error
from typing import ( | |
Any, | |
List, | |
Optional, | |
Tuple, | |
Type, | |
TypedDict, | |
Union, | |
) | |
from langchain_core.documents import Document | |
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config | |
from langchain_ai21.ai21_base import AI21Base | |
ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context" | |
ContextType = Union[str, List[Union[Document, str]]] | |
class ContextualAnswerInput(TypedDict): | |
context: ContextType | |
question: str | |
class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base): | |
class Config: | |
"""Configuration for this pydantic object.""" | |
arbitrary_types_allowed = True | |
def InputType(self) -> Type[ContextualAnswerInput]: | |
"""Get the input type for this runnable.""" | |
return ContextualAnswerInput | |
def OutputType(self) -> Type[str]: | |
"""Get the input type for this runnable.""" | |
return str | |
def invoke( | |
self, | |
input: ContextualAnswerInput, | |
config: Optional[RunnableConfig] = None, | |
response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE, | |
**kwargs: Any, | |
) -> str: | |
config = ensure_config(config) | |
return self._call_with_config( | |
func=lambda inner_input: self._call_contextual_answers( | |
inner_input, response_if_no_answer_found | |
), | |
input=input, | |
config=config, | |
run_type="llm", | |
) | |
def _call_contextual_answers( | |
self, | |
input: ContextualAnswerInput, | |
response_if_no_answer_found: str, | |
) -> str: | |
context, question = self._convert_input(input) | |
response = self.client.answer.create(context=context, question=question) | |
if response.answer is None: | |
return response_if_no_answer_found | |
return response.answer | |
def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]: | |
context, question = self._extract_context_and_question(input) | |
context = self._parse_context(context) | |
return context, question | |
def _extract_context_and_question( | |
self, | |
input: ContextualAnswerInput, | |
) -> Tuple[ContextType, str]: | |
context = input.get("context") | |
question = input.get("question") | |
if not context or not question: | |
raise ValueError( | |
f"Input must contain a 'context' and 'question' fields. Got {input}" | |
) | |
if not isinstance(context, list) and not isinstance(context, str): | |
raise ValueError( | |
f"Expected input to be a list of strings or Documents." | |
f" Received {type(input)}" | |
) | |
return context, question | |
def _parse_context(self, context: ContextType) -> str: | |
if isinstance(context, str): | |
return context | |
docs = [ | |
item.page_content if isinstance(item, Document) else item | |
for item in context | |
] | |
return "\n".join(docs) | |