File size: 1,741 Bytes
28c2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862bece
28c2a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

from base_class import Embedding_Model
import pickle
from sentence_transformers import SentenceTransformer

from openai.embeddings_utils import (
    get_embedding,
)


class HuggingfaceSentenceTransformerModel(Embedding_Model):
    """Embedding model backed by a ``SentenceTransformer`` checkpoint.

    Calling an instance encodes the given text with the loaded model.
    """

    # Checkpoint name used when the caller does not pass ``model_name``.
    EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"

    def __init__(self, model_name=EMBEDDING_MODEL) -> None:
        """Load the checkpoint named ``model_name``.

        Args:
            model_name: Name handed to both the base class and
                ``SentenceTransformer``; defaults to ``EMBEDDING_MODEL``.
        """
        super().__init__(model_name)

        # The cache folder is hard-coded; presumably a path inside the
        # deployment container -- confirm before running elsewhere.
        self.model = SentenceTransformer(model_name, cache_folder="/app/ckpt/")

    def __call__(self, text):
        """Return ``self.model.encode(text)``.

        The previous ``-> None`` annotation was wrong: the encoded result is
        returned to the caller. Its exact type depends on
        ``SentenceTransformer.encode`` and is therefore not annotated here.
        """
        return self.model.encode(text)


class OpenAIEmbeddingModel(Embedding_Model):
    """Embedding model that fetches embeddings via ``get_embedding`` and caches them.

    The cache is a dict mapping ``(text, model)`` tuples to embeddings and is
    saved as a pickle file every time a new entry is added.

    NOTE(review): ``self.embedding_cache`` and ``self.embedding_cache_path``
    are read here but never assigned in this class; presumably
    ``Embedding_Model.__init__`` sets them up -- confirm against the base class.
    """

    # constants
    EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, model_name=EMBEDDING_MODEL) -> None:
        """Initialise the base class and remember which model to request.

        Args:
            model_name: Model identifier passed to ``get_embedding``;
                defaults to ``EMBEDDING_MODEL``.
        """
        super().__init__(model_name)
        self.model_name = model_name

    def embedding_from_string(self,
                              string: str,
                              ) -> list:
        """Return the embedding of ``string``, using the cache to avoid recomputing.

        On a cache miss the embedding is requested through ``get_embedding``,
        stored under ``(string, self.model_name)``, and the whole cache is
        re-pickled to ``self.embedding_cache_path``.

        Args:
            string: Text to embed.

        Returns:
            The cached or freshly requested embedding.
        """
        # Build the key once; it is used for lookup, insertion and return.
        cache_key = (string, self.model_name)
        if cache_key not in self.embedding_cache:
            self.embedding_cache[cache_key] = get_embedding(
                string, self.model_name)
            # Persist immediately so a paid API result is not lost if the
            # process stops before the next write.
            with open(self.embedding_cache_path, "wb") as embedding_cache_file:
                pickle.dump(self.embedding_cache, embedding_cache_file)
        return self.embedding_cache[cache_key]

    def __call__(self, text):
        """Return the embedding of ``text``; see ``embedding_from_string``.

        The previous ``-> None`` annotation was wrong: the embedding is
        returned to the caller.
        """
        return self.embedding_from_string(text)