from typing import Iterable, Iterator, List, Union

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings

# TODO check HuggingFaceInstructEmbeddings


class HuggingFaceTextEmbedding:
    """Embed text with a locally loaded sentence-transformers model.

    Thin wrapper around LangChain's ``HuggingFaceEmbeddings``. The wrapped
    embedder is stored on ``self.model``.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        *,
        device: str = "cpu",
        normalize_embeddings: bool = False,
    ) -> None:
        """Build the underlying ``HuggingFaceEmbeddings`` instance.

        Args:
            model_name: Model identifier passed to ``HuggingFaceEmbeddings``.
            device: Value forwarded as ``model_kwargs['device']``.
            normalize_embeddings: Value forwarded as
                ``encode_kwargs['normalize_embeddings']``.

        The defaults reproduce the configuration that used to be hard-coded,
        so ``HuggingFaceTextEmbedding()`` behaves as before.
        """
        self.model = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": normalize_embeddings},
        )

    def embed_documents(
        self, docs: Iterable[Union[Document, str]]
    ) -> List[List[float]]:
        """Return the embeddings of ``docs`` as computed by ``self.model``.

        Args:
            docs: Documents (their ``page_content`` is embedded) and/or raw
                strings (embedded as-is).

        Returns:
            Whatever ``self.model.embed_documents`` returns for the extracted
            texts; per the LangChain API this is one list of floats per input.
        """
        # The underlying embedder works on plain strings, so unwrap Document
        # objects first; strings pass through untouched so callers that
        # already supplied raw texts keep working.
        texts = [
            doc if isinstance(doc, str) else doc.page_content
            for doc in docs
        ]
        return self.model.embed_documents(texts)


# class HuggingFaceInferenceAPITextEmbedding:
#     def __init__(self) -> None:
#         pass
#
#     def embed_documents(self, docs: Iterable[Document]) -> Iterator[Document]:
#         embeddings = HuggingFaceInferenceAPIEmbeddings(
#             api_key=inference_api_key,
#             model_name="sentence-transformers/all-MiniLM-l6-v2"
#         )
#         chunks = embeddings.embed_documents(docs)
#         return chunks