Spaces:
Runtime error
Runtime error
"""Functionality for loading chains.""" | |
import json | |
from pathlib import Path | |
from typing import Any, Union | |
import yaml | |
from langchain_core.prompts.loading import ( | |
_load_output_parser, | |
load_prompt, | |
load_prompt_from_config, | |
) | |
from langchain.chains import ReduceDocumentsChain | |
from langchain.chains.api.base import APIChain | |
from langchain.chains.base import Chain | |
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain | |
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain | |
from langchain.chains.combine_documents.refine import RefineDocumentsChain | |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain | |
from langchain.chains.graph_qa.cypher import GraphCypherQAChain | |
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder | |
from langchain.chains.llm import LLMChain | |
from langchain.chains.llm_checker.base import LLMCheckerChain | |
from langchain.chains.llm_math.base import LLMMathChain | |
from langchain.chains.llm_requests import LLMRequestsChain | |
from langchain.chains.qa_with_sources.base import QAWithSourcesChain | |
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain | |
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain | |
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA | |
from langchain.llms.loading import load_llm, load_llm_from_config | |
from langchain.utilities.loading import try_load_from_hub | |
# Raw-content URL of the ``chains/`` directory in the langchain-hub repository.
URL_BASE = (
    "https://raw.githubusercontent.com/"
    "hwchase17/langchain-hub/master/chains/"
)
def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain: | |
"""Load LLM chain from config dict.""" | |
if "llm" in config: | |
llm_config = config.pop("llm") | |
llm = load_llm_from_config(llm_config) | |
elif "llm_path" in config: | |
llm = load_llm(config.pop("llm_path")) | |
else: | |
raise ValueError("One of `llm` or `llm_path` must be present.") | |
if "prompt" in config: | |
prompt_config = config.pop("prompt") | |
prompt = load_prompt_from_config(prompt_config) | |
elif "prompt_path" in config: | |
prompt = load_prompt(config.pop("prompt_path")) | |
else: | |
raise ValueError("One of `prompt` or `prompt_path` must be present.") | |
_load_output_parser(config) | |
return LLMChain(llm=llm, prompt=prompt, **config) | |
def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedder: | |
"""Load hypothetical document embedder chain from config dict.""" | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
if "embeddings" in kwargs: | |
embeddings = kwargs.pop("embeddings") | |
else: | |
raise ValueError("`embeddings` must be present.") | |
return HypotheticalDocumentEmbedder( | |
llm_chain=llm_chain, base_embeddings=embeddings, **config | |
) | |
def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain: | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
if not isinstance(llm_chain, LLMChain): | |
raise ValueError(f"Expected LLMChain, got {llm_chain}") | |
if "document_prompt" in config: | |
prompt_config = config.pop("document_prompt") | |
document_prompt = load_prompt_from_config(prompt_config) | |
elif "document_prompt_path" in config: | |
document_prompt = load_prompt(config.pop("document_prompt_path")) | |
else: | |
raise ValueError( | |
"One of `document_prompt` or `document_prompt_path` must be present." | |
) | |
return StuffDocumentsChain( | |
llm_chain=llm_chain, document_prompt=document_prompt, **config | |
) | |
def _load_map_reduce_documents_chain( | |
config: dict, **kwargs: Any | |
) -> MapReduceDocumentsChain: | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
if not isinstance(llm_chain, LLMChain): | |
raise ValueError(f"Expected LLMChain, got {llm_chain}") | |
if "reduce_documents_chain" in config: | |
reduce_documents_chain = load_chain_from_config( | |
config.pop("reduce_documents_chain") | |
) | |
elif "reduce_documents_chain_path" in config: | |
reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path")) | |
else: | |
reduce_documents_chain = _load_reduce_documents_chain(config) | |
return MapReduceDocumentsChain( | |
llm_chain=llm_chain, | |
reduce_documents_chain=reduce_documents_chain, | |
**config, | |
) | |
def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocumentsChain: | |
combine_documents_chain = None | |
collapse_documents_chain = None | |
if "combine_documents_chain" in config: | |
combine_document_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_document_chain_config) | |
elif "combine_document_chain" in config: | |
combine_document_chain_config = config.pop("combine_document_chain") | |
combine_documents_chain = load_chain_from_config(combine_document_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
elif "combine_document_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_document_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
if "collapse_documents_chain" in config: | |
collapse_document_chain_config = config.pop("collapse_documents_chain") | |
if collapse_document_chain_config is None: | |
collapse_documents_chain = None | |
else: | |
collapse_documents_chain = load_chain_from_config( | |
collapse_document_chain_config | |
) | |
elif "collapse_documents_chain_path" in config: | |
collapse_documents_chain = load_chain( | |
config.pop("collapse_documents_chain_path") | |
) | |
elif "collapse_document_chain" in config: | |
collapse_document_chain_config = config.pop("collapse_document_chain") | |
if collapse_document_chain_config is None: | |
collapse_documents_chain = None | |
else: | |
collapse_documents_chain = load_chain_from_config( | |
collapse_document_chain_config | |
) | |
elif "collapse_document_chain_path" in config: | |
collapse_documents_chain = load_chain( | |
config.pop("collapse_document_chain_path") | |
) | |
return ReduceDocumentsChain( | |
combine_documents_chain=combine_documents_chain, | |
collapse_documents_chain=collapse_documents_chain, | |
**config, | |
) | |
def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any:
    """Load an ``LLMBashChain`` (from ``langchain_experimental``) from a config dict.

    Accepts, in order of precedence, ``llm_chain``, ``llm_chain_path``, or the
    deprecated ``llm`` / ``llm_path`` keys kept for old configs. An optional
    ``prompt`` / ``prompt_path`` is forwarded only when present. A config
    without one no longer fails with ``UnboundLocalError``; the chain class's
    own handling of a missing ``prompt`` applies instead.

    Raises:
        ValueError: If none of the chain/LLM keys is present.
    """
    from langchain_experimental.llm_bash.base import LLMBashChain

    llm_chain = None
    llm = None
    if "llm_chain" in config:
        llm_chain = load_chain_from_config(config.pop("llm_chain"))
    elif "llm_chain_path" in config:
        llm_chain = load_chain(config.pop("llm_chain_path"))
    # `llm` is deprecated in favor of `llm_chain`; kept to support old configs.
    elif "llm" in config:
        llm = load_llm_from_config(config.pop("llm"))
    # `llm_path` is deprecated in favor of `llm_chain_path`; kept to support
    # old configs.
    elif "llm_path" in config:
        llm = load_llm(config.pop("llm_path"))
    else:
        raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")

    # Only pass `prompt` when the config actually provides one.
    prompt_kwargs: dict = {}
    if "prompt" in config:
        prompt_kwargs["prompt"] = load_prompt_from_config(config.pop("prompt"))
    elif "prompt_path" in config:
        prompt_kwargs["prompt"] = load_prompt(config.pop("prompt_path"))

    # Explicit None check rather than relying on the chain object's truthiness.
    if llm_chain is not None:
        return LLMBashChain(llm_chain=llm_chain, **prompt_kwargs, **config)
    return LLMBashChain(llm=llm, **prompt_kwargs, **config)
def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: | |
if "llm" in config: | |
llm_config = config.pop("llm") | |
llm = load_llm_from_config(llm_config) | |
elif "llm_path" in config: | |
llm = load_llm(config.pop("llm_path")) | |
else: | |
raise ValueError("One of `llm` or `llm_path` must be present.") | |
if "create_draft_answer_prompt" in config: | |
create_draft_answer_prompt_config = config.pop("create_draft_answer_prompt") | |
create_draft_answer_prompt = load_prompt_from_config( | |
create_draft_answer_prompt_config | |
) | |
elif "create_draft_answer_prompt_path" in config: | |
create_draft_answer_prompt = load_prompt( | |
config.pop("create_draft_answer_prompt_path") | |
) | |
if "list_assertions_prompt" in config: | |
list_assertions_prompt_config = config.pop("list_assertions_prompt") | |
list_assertions_prompt = load_prompt_from_config(list_assertions_prompt_config) | |
elif "list_assertions_prompt_path" in config: | |
list_assertions_prompt = load_prompt(config.pop("list_assertions_prompt_path")) | |
if "check_assertions_prompt" in config: | |
check_assertions_prompt_config = config.pop("check_assertions_prompt") | |
check_assertions_prompt = load_prompt_from_config( | |
check_assertions_prompt_config | |
) | |
elif "check_assertions_prompt_path" in config: | |
check_assertions_prompt = load_prompt( | |
config.pop("check_assertions_prompt_path") | |
) | |
if "revised_answer_prompt" in config: | |
revised_answer_prompt_config = config.pop("revised_answer_prompt") | |
revised_answer_prompt = load_prompt_from_config(revised_answer_prompt_config) | |
elif "revised_answer_prompt_path" in config: | |
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path")) | |
return LLMCheckerChain( | |
llm=llm, | |
create_draft_answer_prompt=create_draft_answer_prompt, | |
list_assertions_prompt=list_assertions_prompt, | |
check_assertions_prompt=check_assertions_prompt, | |
revised_answer_prompt=revised_answer_prompt, | |
**config, | |
) | |
def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain: | |
llm_chain = None | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
# llm attribute is deprecated in favor of llm_chain, here to support old configs | |
elif "llm" in config: | |
llm_config = config.pop("llm") | |
llm = load_llm_from_config(llm_config) | |
# llm_path attribute is deprecated in favor of llm_chain_path, | |
# its to support old configs | |
elif "llm_path" in config: | |
llm = load_llm(config.pop("llm_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
if "prompt" in config: | |
prompt_config = config.pop("prompt") | |
prompt = load_prompt_from_config(prompt_config) | |
elif "prompt_path" in config: | |
prompt = load_prompt(config.pop("prompt_path")) | |
if llm_chain: | |
return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) | |
else: | |
return LLMMathChain(llm=llm, prompt=prompt, **config) | |
def _load_map_rerank_documents_chain( | |
config: dict, **kwargs: Any | |
) -> MapRerankDocumentsChain: | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) | |
def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
    """Load a ``PALChain`` (from ``langchain_experimental``) from a config dict.

    The inner chain comes from ``llm_chain`` (inline) or ``llm_chain_path``;
    all other config keys are forwarded to the constructor.

    Raises:
        ValueError: If neither ``llm_chain`` nor ``llm_chain_path`` is present.
    """
    # Imported lazily so the experimental package is needed only for this chain.
    from langchain_experimental.pal_chain import PALChain

    if "llm_chain" not in config and "llm_chain_path" not in config:
        raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
    inner = (
        load_chain_from_config(config.pop("llm_chain"))
        if "llm_chain" in config
        else load_chain(config.pop("llm_chain_path"))
    )
    return PALChain(llm_chain=inner, **config)
def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain: | |
if "initial_llm_chain" in config: | |
initial_llm_chain_config = config.pop("initial_llm_chain") | |
initial_llm_chain = load_chain_from_config(initial_llm_chain_config) | |
elif "initial_llm_chain_path" in config: | |
initial_llm_chain = load_chain(config.pop("initial_llm_chain_path")) | |
else: | |
raise ValueError( | |
"One of `initial_llm_chain` or `initial_llm_chain_path` must be present." | |
) | |
if "refine_llm_chain" in config: | |
refine_llm_chain_config = config.pop("refine_llm_chain") | |
refine_llm_chain = load_chain_from_config(refine_llm_chain_config) | |
elif "refine_llm_chain_path" in config: | |
refine_llm_chain = load_chain(config.pop("refine_llm_chain_path")) | |
else: | |
raise ValueError( | |
"One of `refine_llm_chain` or `refine_llm_chain_path` must be present." | |
) | |
if "document_prompt" in config: | |
prompt_config = config.pop("document_prompt") | |
document_prompt = load_prompt_from_config(prompt_config) | |
elif "document_prompt_path" in config: | |
document_prompt = load_prompt(config.pop("document_prompt_path")) | |
return RefineDocumentsChain( | |
initial_llm_chain=initial_llm_chain, | |
refine_llm_chain=refine_llm_chain, | |
document_prompt=document_prompt, | |
**config, | |
) | |
def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain: | |
if "combine_documents_chain" in config: | |
combine_documents_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) | |
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
    """Load an ``SQLDatabaseChain`` (from ``langchain_experimental``).

    The database is not read from ``config``; it must be passed as the
    ``database`` keyword argument.

    If the config holds an inline ``llm_chain``, the chain is built from it
    directly. Otherwise an LLM (``llm`` or ``llm_path``) and an optional inline
    ``prompt`` are handed to ``SQLDatabaseChain.from_llm``.

    Raises:
        ValueError: If ``database`` is missing, or no LLM source is present
            in the non-``llm_chain`` path.
    """
    from langchain_experimental.sql import SQLDatabaseChain

    if "database" not in kwargs:
        raise ValueError("`database` must be present.")
    database = kwargs.pop("database")

    # A serialized inner chain short-circuits the LLM/prompt handling below.
    if "llm_chain" in config:
        inner = load_chain_from_config(config.pop("llm_chain"))
        return SQLDatabaseChain(llm_chain=inner, database=database, **config)

    if "llm" in config:
        llm = load_llm_from_config(config.pop("llm"))
    elif "llm_path" in config:
        llm = load_llm(config.pop("llm_path"))
    else:
        raise ValueError("One of `llm` or `llm_path` must be present.")

    prompt = (
        load_prompt_from_config(config.pop("prompt")) if "prompt" in config else None
    )
    return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)
def _load_vector_db_qa_with_sources_chain( | |
config: dict, **kwargs: Any | |
) -> VectorDBQAWithSourcesChain: | |
if "vectorstore" in kwargs: | |
vectorstore = kwargs.pop("vectorstore") | |
else: | |
raise ValueError("`vectorstore` must be present.") | |
if "combine_documents_chain" in config: | |
combine_documents_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
return VectorDBQAWithSourcesChain( | |
combine_documents_chain=combine_documents_chain, | |
vectorstore=vectorstore, | |
**config, | |
) | |
def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA: | |
if "retriever" in kwargs: | |
retriever = kwargs.pop("retriever") | |
else: | |
raise ValueError("`retriever` must be present.") | |
if "combine_documents_chain" in config: | |
combine_documents_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
return RetrievalQA( | |
combine_documents_chain=combine_documents_chain, | |
retriever=retriever, | |
**config, | |
) | |
def _load_retrieval_qa_with_sources_chain( | |
config: dict, **kwargs: Any | |
) -> RetrievalQAWithSourcesChain: | |
if "retriever" in kwargs: | |
retriever = kwargs.pop("retriever") | |
else: | |
raise ValueError("`retriever` must be present.") | |
if "combine_documents_chain" in config: | |
combine_documents_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
return RetrievalQAWithSourcesChain( | |
combine_documents_chain=combine_documents_chain, | |
retriever=retriever, | |
**config, | |
) | |
def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA: | |
if "vectorstore" in kwargs: | |
vectorstore = kwargs.pop("vectorstore") | |
else: | |
raise ValueError("`vectorstore` must be present.") | |
if "combine_documents_chain" in config: | |
combine_documents_chain_config = config.pop("combine_documents_chain") | |
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) | |
elif "combine_documents_chain_path" in config: | |
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) | |
else: | |
raise ValueError( | |
"One of `combine_documents_chain` or " | |
"`combine_documents_chain_path` must be present." | |
) | |
return VectorDBQA( | |
combine_documents_chain=combine_documents_chain, | |
vectorstore=vectorstore, | |
**config, | |
) | |
def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain: | |
if "graph" in kwargs: | |
graph = kwargs.pop("graph") | |
else: | |
raise ValueError("`graph` must be present.") | |
if "cypher_generation_chain" in config: | |
cypher_generation_chain_config = config.pop("cypher_generation_chain") | |
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config) | |
else: | |
raise ValueError("`cypher_generation_chain` must be present.") | |
if "qa_chain" in config: | |
qa_chain_config = config.pop("qa_chain") | |
qa_chain = load_chain_from_config(qa_chain_config) | |
else: | |
raise ValueError("`qa_chain` must be present.") | |
return GraphCypherQAChain( | |
graph=graph, | |
cypher_generation_chain=cypher_generation_chain, | |
qa_chain=qa_chain, | |
**config, | |
) | |
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain: | |
if "api_request_chain" in config: | |
api_request_chain_config = config.pop("api_request_chain") | |
api_request_chain = load_chain_from_config(api_request_chain_config) | |
elif "api_request_chain_path" in config: | |
api_request_chain = load_chain(config.pop("api_request_chain_path")) | |
else: | |
raise ValueError( | |
"One of `api_request_chain` or `api_request_chain_path` must be present." | |
) | |
if "api_answer_chain" in config: | |
api_answer_chain_config = config.pop("api_answer_chain") | |
api_answer_chain = load_chain_from_config(api_answer_chain_config) | |
elif "api_answer_chain_path" in config: | |
api_answer_chain = load_chain(config.pop("api_answer_chain_path")) | |
else: | |
raise ValueError( | |
"One of `api_answer_chain` or `api_answer_chain_path` must be present." | |
) | |
if "requests_wrapper" in kwargs: | |
requests_wrapper = kwargs.pop("requests_wrapper") | |
else: | |
raise ValueError("`requests_wrapper` must be present.") | |
return APIChain( | |
api_request_chain=api_request_chain, | |
api_answer_chain=api_answer_chain, | |
requests_wrapper=requests_wrapper, | |
**config, | |
) | |
def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: | |
if "llm_chain" in config: | |
llm_chain_config = config.pop("llm_chain") | |
llm_chain = load_chain_from_config(llm_chain_config) | |
elif "llm_chain_path" in config: | |
llm_chain = load_chain(config.pop("llm_chain_path")) | |
else: | |
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") | |
if "requests_wrapper" in kwargs: | |
requests_wrapper = kwargs.pop("requests_wrapper") | |
return LLMRequestsChain( | |
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config | |
) | |
else: | |
return LLMRequestsChain(llm_chain=llm_chain, **config) | |
# Maps each serialized ``_type`` value to the function that rebuilds that chain.
type_to_loader_dict = dict(
    api_chain=_load_api_chain,
    hyde_chain=_load_hyde_chain,
    llm_chain=_load_llm_chain,
    llm_bash_chain=_load_llm_bash_chain,
    llm_checker_chain=_load_llm_checker_chain,
    llm_math_chain=_load_llm_math_chain,
    llm_requests_chain=_load_llm_requests_chain,
    pal_chain=_load_pal_chain,
    qa_with_sources_chain=_load_qa_with_sources_chain,
    stuff_documents_chain=_load_stuff_documents_chain,
    map_reduce_documents_chain=_load_map_reduce_documents_chain,
    reduce_documents_chain=_load_reduce_documents_chain,
    map_rerank_documents_chain=_load_map_rerank_documents_chain,
    refine_documents_chain=_load_refine_documents_chain,
    sql_database_chain=_load_sql_database_chain,
    vector_db_qa_with_sources_chain=_load_vector_db_qa_with_sources_chain,
    vector_db_qa=_load_vector_db_qa,
    retrieval_qa=_load_retrieval_qa,
    retrieval_qa_with_sources_chain=_load_retrieval_qa_with_sources_chain,
    graph_cypher_chain=_load_graph_cypher_chain,
)
def load_chain_from_config(config: dict, **kwargs: Any) -> Chain:
    """Load chain from Config Dict.

    The ``_type`` entry is popped from ``config`` and used to pick a loader
    from ``type_to_loader_dict``. The remaining config and ``kwargs`` are then
    handed to that loader.

    Raises:
        ValueError: If ``_type`` is missing or names an unsupported chain type.
    """
    if "_type" not in config:
        raise ValueError("Must specify a chain Type in config")
    config_type = config.pop("_type")

    loader = type_to_loader_dict.get(config_type)
    if loader is None:
        raise ValueError(f"Loading {config_type} chain not supported")
    return loader(config, **kwargs)
def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
    """Unified method for loading a chain from LangChainHub or local fs.

    The hub loader is tried first. If it yields a falsy result, ``path`` is
    loaded as a local JSON/YAML file instead. ``kwargs`` are forwarded to
    whichever loader runs.
    """
    from_hub = try_load_from_hub(
        path, _load_chain_from_file, "chains", {"json", "yaml"}, **kwargs
    )
    if from_hub:
        return from_hub
    return _load_chain_from_file(path, **kwargs)
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: | |
"""Load chain from file.""" | |
# Convert file to Path object. | |
if isinstance(file, str): | |
file_path = Path(file) | |
else: | |
file_path = file | |
# Load from either json or yaml. | |
if file_path.suffix == ".json": | |
with open(file_path) as f: | |
config = json.load(f) | |
elif file_path.suffix == ".yaml": | |
with open(file_path, "r") as f: | |
config = yaml.safe_load(f) | |
else: | |
raise ValueError("File type must be json or yaml") | |
# Override default 'verbose' and 'memory' for the chain | |
if "verbose" in kwargs: | |
config["verbose"] = kwargs.pop("verbose") | |
if "memory" in kwargs: | |
config["memory"] = kwargs.pop("memory") | |
# Load the chain from the config now. | |
return load_chain_from_config(config, **kwargs) | |