|
|
|
from base_class import Embedding_Model |
|
import pickle |
|
from sentence_transformers import SentenceTransformer |
|
|
|
from openai.embeddings_utils import ( |
|
get_embedding, |
|
) |
|
|
|
|
|
class HuggingfaceSentenceTransformerModel(Embedding_Model):
    """Embedding model backed by a ``SentenceTransformer`` instance.

    Calling an instance forwards the input to ``SentenceTransformer.encode``.
    """

    # Model name used when the caller does not supply one.
    EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"
    # Default directory handed to SentenceTransformer as ``cache_folder``.
    CACHE_FOLDER = "/app/ckpt/"

    def __init__(
        self,
        model_name: str = EMBEDDING_MODEL,
        cache_folder: str = CACHE_FOLDER,
    ) -> None:
        """Load the sentence-transformer model.

        Args:
            model_name: Name passed to the base class initializer and to
                ``SentenceTransformer``.
            cache_folder: Directory passed to ``SentenceTransformer`` as its
                ``cache_folder`` argument. Defaults to ``"/app/ckpt/"``.
        """
        # NOTE(review): what Embedding_Model.__init__ does with model_name is
        # not visible from this file -- confirm against base_class.
        super().__init__(model_name)
        self.model = SentenceTransformer(model_name, cache_folder=cache_folder)

    def __call__(self, text):
        """Return the result of ``self.model.encode(text)``.

        The return type is whatever ``SentenceTransformer.encode`` produces
        (presumably a NumPy array with default options -- confirm against the
        installed sentence-transformers version). It is never ``None``.
        """
        return self.model.encode(text)
|
|
|
|
|
class OpenAIEmbeddingModel(Embedding_Model):
    """Embedding model that fetches embeddings through ``get_embedding``.

    Results are memoized in ``self.embedding_cache`` and the cache is pickled
    to ``self.embedding_cache_path`` after every newly computed embedding.

    NOTE(review): ``embedding_cache`` and ``embedding_cache_path`` are not
    assigned in this class; presumably ``Embedding_Model`` provides them --
    confirm against base_class.
    """

    # Model name used when the caller does not supply one.
    EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, model_name: str = EMBEDDING_MODEL) -> None:
        """Initialize the base class and remember which model to request.

        Args:
            model_name: Model identifier passed to ``get_embedding``.
        """
        super().__init__(model_name)
        self.model_name = model_name

    def embedding_from_string(
        self,
        string: str,
    ) -> list:
        """Return embedding of given string, using a cache to avoid recomputing.

        Args:
            string: Text to embed.

        Returns:
            The value stored in the cache for ``(string, self.model_name)``,
            i.e. whatever ``get_embedding`` returned for that pair.
        """
        # Cache entries are keyed by text AND model name, so the same text
        # embedded with a different model is stored separately.
        cache_key = (string, self.model_name)
        if cache_key not in self.embedding_cache:
            self.embedding_cache[cache_key] = get_embedding(
                string, self.model_name
            )
            # Persist the whole cache right after each miss so the newly
            # computed embedding is written to disk immediately.
            with open(self.embedding_cache_path, "wb") as embedding_cache_file:
                pickle.dump(self.embedding_cache, embedding_cache_file)
        return self.embedding_cache[cache_key]

    def __call__(self, text: str) -> list:
        """Return the (possibly cached) embedding for ``text``."""
        return self.embedding_from_string(text)
|
|