Spaces:
Runtime error
Runtime error
"""Running custom embedding models on self-hosted remote hardware.""" | |
from typing import Any, Callable, List | |
from pydantic import BaseModel, Extra | |
from langchain.embeddings.base import Embeddings | |
from langchain.llms import SelfHostedPipeline | |
def _embed_documents(pipeline: Any, *args: Any, **kwargs: Any) -> List[List[float]]: | |
"""Inference function to send to the remote hardware. | |
Accepts a sentence_transformer model_id and | |
returns a list of embeddings for each document in the batch. | |
""" | |
return pipeline(*args, **kwargs) | |
class SelfHostedEmbeddings(SelfHostedPipeline, Embeddings, BaseModel):
    """Runs custom embedding models on self-hosted remote hardware.

    Supported hardware includes auto-launched instances on AWS, GCP, Azure,
    and Lambda, as well as servers specified
    by IP address and SSH credentials (such as on-prem, or another
    cloud like Paperspace, Coreweave, etc.).

    To use, you should have the ``runhouse`` python package installed.

    Example using a model load function:
        .. code-block:: python

            from langchain.embeddings import SelfHostedEmbeddings
            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
            import runhouse as rh

            gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
            def get_pipeline():
                model_id = "facebook/bart-large"
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForCausalLM.from_pretrained(model_id)
                return pipeline(
                    "feature-extraction", model=model, tokenizer=tokenizer
                )
            embeddings = SelfHostedEmbeddings(
                model_load_fn=get_pipeline,
                hardware=gpu,
                model_reqs=["./", "torch", "transformers"],
            )

    Example passing in a pipeline path:
        .. code-block:: python

            import pickle

            from langchain.embeddings import SelfHostedHFEmbeddings
            import runhouse as rh
            from transformers import pipeline

            gpu = rh.cluster(name="rh-a10x", instance_type="A100:1")
            pipeline = pipeline(model="bert-base-uncased", task="feature-extraction")
            rh.blob(pickle.dumps(pipeline),
                path="models/pipeline.pkl").save().to(gpu, path="models")
            embeddings = SelfHostedHFEmbeddings.from_pipeline(
                pipeline="models/pipeline.pkl",
                hardware=gpu,
                model_reqs=["./", "torch", "transformers"],
            )
    """

    inference_fn: Callable = _embed_documents
    """Inference function to extract the embeddings on the remote hardware."""
    inference_kwargs: Any = None
    """Any kwargs to pass to the model's inference function."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def _run_inference(self, inputs: Any) -> Any:
        """Send ``inputs`` to the remote client and return the result as a list.

        ``self.client`` and ``self.pipeline_ref`` are not defined in this
        class; presumably they are provided by ``SelfHostedPipeline``.

        Args:
            inputs: A single text or a list of texts to embed.

        Returns:
            The client's result, converted with ``.tolist()`` when it is not
            already a ``list``.
        """
        # Forward the user-configured inference kwargs; with the default of
        # ``None`` the call is identical to passing no extra keyword arguments.
        extra_kwargs = self.inference_kwargs or {}
        embeddings = self.client(self.pipeline_ref, inputs, **extra_kwargs)
        if not isinstance(embeddings, list):
            # Non-list results (e.g. array-like objects) are normalised to
            # plain Python lists via their ``tolist`` method.
            return embeddings.tolist()
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute document embeddings using the self-hosted pipeline.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Newlines are replaced by spaces before the texts are sent for inference.
        cleaned = [text.replace("\n", " ") for text in texts]
        return self._run_inference(cleaned)

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using the self-hosted pipeline.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._run_inference(text.replace("\n", " "))