# Spaces:
# Runtime error
# Runtime error
"""Hypothetical Document Embeddings.

https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations | |
from typing import Any, Dict, List, Optional | |
import numpy as np | |
from langchain_core.embeddings import Embeddings | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.pydantic_v1 import Extra | |
from langchain.callbacks.manager import CallbackManagerForChainRun | |
from langchain.chains.base import Chain | |
from langchain.chains.hyde.prompts import PROMPT_MAP | |
from langchain.chains.llm import LLMChain | |
class HypotheticalDocumentEmbedder(Chain, Embeddings):
    """Generate hypothetical documents for a query, then embed those.

    Based on https://arxiv.org/abs/2212.10496

    The wrapped ``llm_chain`` produces one or more candidate documents for a
    query; each is embedded with ``base_embeddings`` and the resulting vectors
    are averaged into a single query embedding.
    """

    # Embedding model used for both real documents and generated ones.
    base_embeddings: Embeddings
    # Chain that turns a query into hypothetical document text.
    llm_chain: LLMChain

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys for Hyde's LLM chain."""
        return self.llm_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Output keys for Hyde's LLM chain."""
        return self.llm_chain.output_keys

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` by delegating directly to the base embeddings.

        Args:
            texts: Documents to embed.

        Returns:
            Whatever ``base_embeddings.embed_documents`` returns for ``texts``.
        """
        return self.base_embeddings.embed_documents(texts)

    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
        """Combine several embeddings into one by element-wise averaging.

        Args:
            embeddings: Vectors to combine, one per row.

        Returns:
            The mean taken across rows (``axis=0``), as a list.
        """
        return list(np.array(embeddings).mean(axis=0))

    def embed_query(self, text: str) -> List[float]:
        """Generate hypothetical documents for ``text`` and embed them.

        Args:
            text: The query to embed.

        Returns:
            The combined embedding of every generated document.
        """
        # The query is fed to the LLM chain under its first declared input key.
        var_name = self.llm_chain.input_keys[0]
        result = self.llm_chain.generate([{var_name: text}])
        # Only one input was submitted, so only the first generation list
        # is read; it may hold several candidate documents.
        documents = [generation.text for generation in result.generations[0]]
        embeddings = self.embed_documents(documents)
        return self.combine_embeddings(embeddings)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Call the internal llm chain.

        Args:
            inputs: Inputs forwarded unchanged to ``llm_chain``.
            run_manager: Callback manager for this run; a no-op manager is
                substituted when none is given.

        Returns:
            The output of ``llm_chain`` for ``inputs``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.llm_chain(inputs, callbacks=_run_manager.get_child())

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        base_embeddings: Embeddings,
        prompt_key: str,
        **kwargs: Any,
    ) -> HypotheticalDocumentEmbedder:
        """Load and use LLMChain for a specific prompt key.

        Args:
            llm: Language model used to generate hypothetical documents.
            base_embeddings: Embedding model applied to the generated text.
            prompt_key: Key into ``PROMPT_MAP`` selecting the prompt.
            **kwargs: Extra fields forwarded to the class constructor.

        Returns:
            A new embedder wrapping an ``LLMChain`` built from ``llm`` and
            the selected prompt.

        Raises:
            KeyError: If ``prompt_key`` is not present in ``PROMPT_MAP``.
        """
        prompt = PROMPT_MAP[prompt_key]
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "hyde_chain"