File size: 1,741 Bytes
28c2a3d 862bece 28c2a3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from base_class import Embedding_Model
import pickle
from sentence_transformers import SentenceTransformer
from openai.embeddings_utils import (
get_embedding,
)
class HuggingfaceSentenceTransformerModel(Embedding_Model):
    """Embedding model that delegates to a ``SentenceTransformer`` instance."""

    EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"
    # Default value forwarded to SentenceTransformer's ``cache_folder`` argument.
    CACHE_FOLDER = "/app/ckpt/"

    def __init__(self, model_name=EMBEDDING_MODEL, cache_folder=CACHE_FOLDER) -> None:
        """Build the underlying SentenceTransformer.

        Args:
            model_name: Model identifier passed to the base class and to
                ``SentenceTransformer``.
            cache_folder: Forwarded to ``SentenceTransformer`` as
                ``cache_folder``; defaults to the previously hard-coded
                ``"/app/ckpt/"`` so existing callers are unaffected.
        """
        super().__init__(model_name)
        self.model = SentenceTransformer(model_name, cache_folder=cache_folder)

    def __call__(self, text):
        """Return the result of ``self.model.encode(text)``.

        The previous ``-> None`` annotation was wrong: the encoded value is
        returned to the caller.
        """
        return self.model.encode(text)
class OpenAIEmbeddingModel(Embedding_Model):
    """Embedding model that calls ``get_embedding`` and memoizes the results.

    The cache is a dict mapping ``(text, model_name)`` to the embedding.
    When ``embedding_cache_path`` is set, the whole dict is re-pickled to
    that path after every newly computed embedding.
    """

    # constants
    EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, model_name=EMBEDDING_MODEL, embedding_cache_path=None) -> None:
        """Set up the model name and the embedding cache.

        Args:
            model_name: Model identifier passed to the base class and to
                ``get_embedding``.
            embedding_cache_path: Optional path of a pickle file used to
                persist the cache. When omitted, a value already present on
                the instance is kept; otherwise the cache stays in memory.
        """
        super().__init__(model_name)
        self.model_name = model_name
        # The base class is not visible from this file and may already set
        # these attributes, so only fill in what is missing instead of
        # overwriting. Previously nothing here assigned them at all.
        if embedding_cache_path is not None:
            self.embedding_cache_path = embedding_cache_path
        elif not hasattr(self, "embedding_cache_path"):
            self.embedding_cache_path = None
        if not hasattr(self, "embedding_cache"):
            self.embedding_cache = self._load_cache(self.embedding_cache_path)

    @staticmethod
    def _load_cache(path):
        """Return the pickled cache stored at ``path``, or ``{}`` if unavailable."""
        if path is None:
            return {}
        try:
            # NOTE(security): unpickling can execute arbitrary code; only
            # point the cache path at files this application wrote itself.
            with open(path, "rb") as embedding_cache_file:
                return pickle.load(embedding_cache_file)
        except FileNotFoundError:
            # No cache has been written yet: start empty.
            return {}

    def embedding_from_string(self, string: str) -> list:
        """Return embedding of given string, using a cache to avoid recomputing."""
        key = (string, self.model_name)
        if key not in self.embedding_cache:
            self.embedding_cache[key] = get_embedding(string, self.model_name)
            # Persist only when a cache file is configured.
            if self.embedding_cache_path is not None:
                with open(self.embedding_cache_path, "wb") as embedding_cache_file:
                    pickle.dump(self.embedding_cache, embedding_cache_file)
        return self.embedding_cache[key]

    def __call__(self, text) -> list:
        """Return the (possibly cached) embedding for ``text``."""
        return self.embedding_from_string(text)
|