File size: 4,680 Bytes
f1e6b80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""Interface with the LangChain Hub."""
from __future__ import annotations
import json
from typing import Any, Optional, Sequence
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts import BasePromptTemplate
def _get_client(
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
) -> Any:
    """Build the client used to talk to the hub.

    A ``langsmith`` client is preferred when that package can be imported and
    the constructed client exposes both ``push_prompt`` and ``pull_prompt``.
    Otherwise the deprecated ``langchainhub`` client is used.

    :param api_key: API key forwarded to the client constructor.
    :param api_url: API URL forwarded to the client constructor as its first
        positional argument.
    :return: A ``langsmith`` or ``langchainhub`` client instance.
    :raises ImportError: If ``langchainhub`` cannot be imported and
        ``langsmith`` is either missing or lacks the prompt methods.
    """
    # Keep each ``try`` limited to the import itself so that errors raised
    # while constructing a client are not mistaken for a missing package.
    try:
        from langsmith import Client as LangSmithClient
    except ImportError:
        # langsmith is absent; the langchainhub fallback below decides.
        pass
    else:
        ls_client = LangSmithClient(api_url, api_key=api_key)
        # Only langsmith versions that ship the prompt methods can be used.
        if hasattr(ls_client, "push_prompt") and hasattr(ls_client, "pull_prompt"):
            return ls_client

    try:
        from langchainhub import Client as LangChainHubClient
    except ImportError as e:
        # The trailing space matters: adjacent literals are concatenated.
        raise ImportError(
            "Could not import langsmith or langchainhub (deprecated), "
            "please install with `pip install langsmith`."
        ) from e
    return LangChainHubClient(api_url, api_key=api_key)
def push(
    repo_full_name: str,
    object: Any,
    *,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
    parent_commit_hash: Optional[str] = None,
    new_repo_is_public: bool = False,
    new_repo_description: Optional[str] = None,
    readme: Optional[str] = None,
    tags: Optional[Sequence[str]] = None,
) -> str:
    """
    Upload an object to the hub and return the URL where it can be viewed in a
    browser.

    :param repo_full_name: Full name of the target prompt, written as
        `owner/prompt_name` or `prompt_name`.
    :param object: The LangChain object to serialize and push to the hub.
    :param api_url: URL of the LangChain Hub API. Defaults to the hosted API
        service if you have an api key set, or a localhost instance if not.
    :param api_key: API key used to authenticate with the LangChain Hub API.
    :param parent_commit_hash: Commit hash of the parent commit to push to.
        Defaults to the latest commit automatically.
    :param new_repo_is_public: Whether the prompt should be public. Defaults
        to False (private by default).
    :param new_repo_description: Description of the prompt. Defaults to an
        empty string.
    :param readme: Passed to the client's `push_prompt`. Not forwarded when
        the legacy `langchainhub` client is in use.
    :param tags: Passed to the client's `push_prompt`. Not forwarded when the
        legacy `langchainhub` client is in use.
    :return: Whatever the underlying client's push call returns.
    """
    hub_client = _get_client(api_key=api_key, api_url=api_url)

    if not hasattr(hub_client, "push_prompt"):
        # No `push_prompt`: this is the legacy langchainhub client, which
        # takes the already-serialized manifest.
        serialized = dumps(object)
        return hub_client.push(
            repo_full_name,
            serialized,
            parent_commit_hash=parent_commit_hash,
            new_repo_is_public=new_repo_is_public,
            new_repo_description=new_repo_description,
        )

    # langsmith client: it receives the object itself.
    return hub_client.push_prompt(
        repo_full_name,
        object=object,
        parent_commit_hash=parent_commit_hash,
        is_public=new_repo_is_public,
        description=new_repo_description,
        readme=readme,
        tags=tags,
    )
def pull(
    owner_repo_commit: str,
    *,
    include_model: Optional[bool] = None,
    api_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Any:
    """
    Fetch an object from the hub and return it as a LangChain object.

    :param owner_repo_commit: Full name of the prompt to pull, written as
        `owner/prompt_name:commit_hash` or `owner/prompt_name`,
        or just `prompt_name` if it's your own prompt.
    :param include_model: Passed to the client's `pull_prompt`. Not used when
        the legacy `langchainhub` client is in use.
    :param api_url: URL of the LangChain Hub API. Defaults to the hosted API
        service if you have an api key set, or a localhost instance if not.
    :param api_key: API key used to authenticate with the LangChain Hub API.
    :return: The result of `pull_prompt` for a langsmith client, otherwise the
        object deserialized from the manifest returned by langchainhub.
    """
    hub_client = _get_client(api_key=api_key, api_url=api_url)

    # langsmith client
    if hasattr(hub_client, "pull_prompt"):
        return hub_client.pull_prompt(owner_repo_commit, include_model=include_model)

    # langchainhub < 0.1.15 has no `pull_repo` and returns the manifest string.
    if not hasattr(hub_client, "pull_repo"):
        raw_manifest: str = hub_client.pull(owner_repo_commit)
        return loads(raw_manifest)

    # langchainhub >= 0.1.15
    repo_info = hub_client.pull_repo(owner_repo_commit)
    loaded = loads(json.dumps(repo_info["manifest"]))
    if isinstance(loaded, BasePromptTemplate):
        if loaded.metadata is None:
            loaded.metadata = {}
        # Record on the prompt template which hub commit it was loaded from.
        for meta_key, info_key in (
            ("lc_hub_owner", "owner"),
            ("lc_hub_repo", "repo"),
            ("lc_hub_commit_hash", "commit_hash"),
        ):
            loaded.metadata[meta_key] = repo_info[info_key]
    return loaded
|