Spaces:
Runtime error
Runtime error
from typing import Any, List, Optional | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel | |
from langchain.chains.base import Chain | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.openai_functions.utils import ( | |
_convert_schema, | |
_resolve_schema_references, | |
get_llm_kwargs, | |
) | |
from langchain.output_parsers.openai_functions import ( | |
JsonKeyOutputFunctionsParser, | |
PydanticAttrOutputFunctionsParser, | |
) | |
def _get_extraction_function(entity_schema: dict) -> dict:
    """Build the ``information_extraction`` function description.

    Args:
        entity_schema: Schema of a single entity; it is passed through
            ``_convert_schema`` and used as the item type of the ``info`` array.

    Returns:
        A dict with ``name``, ``description`` and ``parameters`` keys, where
        ``parameters`` is an object schema whose only (and required) property
        is ``info``, an array of converted entities.
    """
    # Each extracted entity is one element of the "info" array.
    entity_array = {"type": "array", "items": _convert_schema(entity_schema)}
    parameters = {
        "type": "object",
        "properties": {"info": entity_array},
        "required": ["info"],
    }
    return {
        "name": "information_extraction",
        "description": "Extracts the relevant information from the passage.",
        "parameters": parameters,
    }
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \ | |
in the following passage together with their properties. | |
Only extract the properties mentioned in the 'information_extraction' function. | |
If a property is not present and is not required in the function parameters, do not include it in the output. | |
Passage: | |
{input} | |
""" # noqa: E501 | |
def create_extraction_chain(
    schema: dict,
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    tags: Optional[List[str]] = None,
    verbose: bool = False,
) -> Chain:
    """Create a chain that extracts entities from a passage.

    Args:
        schema: The schema of the entities to extract.
        llm: The language model to use.
        prompt: The prompt to use for extraction. When omitted, a chat prompt
            built from ``_EXTRACTION_TEMPLATE`` is used.
        tags: Optional tags forwarded to the underlying ``LLMChain``.
        verbose: Whether to run in verbose mode; forwarded to the underlying
            ``LLMChain``. Defaults to ``False``.

    Returns:
        Chain that can be used to extract information from a passage.
    """
    if prompt:
        chosen_prompt = prompt
    else:
        chosen_prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)

    # The function description drives both the kwargs sent to the model and
    # the key ("info") that the parser pulls out of the function-call output.
    extraction_fn = _get_extraction_function(schema)
    return LLMChain(
        llm=llm,
        prompt=chosen_prompt,
        llm_kwargs=get_llm_kwargs(extraction_fn),
        output_parser=JsonKeyOutputFunctionsParser(key_name="info"),
        tags=tags,
        verbose=verbose,
    )
def create_extraction_chain_pydantic(
    pydantic_schema: Any,
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    verbose: bool = False,
    tags: Optional[List[str]] = None,
) -> Chain:
    """Create a chain that extracts entities from a passage using a pydantic schema.

    Args:
        pydantic_schema: The pydantic model class describing one entity to
            extract.
        llm: The language model to use.
        prompt: The prompt to use for extraction. When omitted, a chat prompt
            built from ``_EXTRACTION_TEMPLATE`` is used.
        verbose: Whether to run in verbose mode; forwarded to the underlying
            ``LLMChain``. Defaults to ``False``.
        tags: Optional tags forwarded to the underlying ``LLMChain``, matching
            ``create_extraction_chain``. Placed last so existing positional
            calls keep working.

    Returns:
        Chain that can be used to extract information from a passage.
    """

    # Wrapper model mirroring the function parameters: a single "info" field
    # holding a list of the caller's entities. The output parser below is
    # configured to validate against it and read its ``info`` attribute.
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    openai_schema = pydantic_schema.schema()
    # Resolve references against the schema's own "definitions" section
    # (an empty mapping when the schema has none).
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema.get("definitions", {})
    )

    function = _get_extraction_function(openai_schema)
    extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    output_parser = PydanticAttrOutputFunctionsParser(
        pydantic_schema=PydanticSchema, attr_name="info"
    )
    llm_kwargs = get_llm_kwargs(function)
    chain = LLMChain(
        llm=llm,
        prompt=extraction_prompt,
        llm_kwargs=llm_kwargs,
        output_parser=output_parser,
        tags=tags,
        verbose=verbose,
    )
    return chain