File size: 2,099 Bytes
7e02cc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import google.generativeai as genai
from dotenv import load_dotenv
import os
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_cohere import CohereEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

# Populate os.environ from a local .env file (if present) so the
# GOOGLE_API_KEY / OPENAI_API_KEY / COHERE_API_KEY lookups in the
# Embeddings class below can find provider credentials.
load_dotenv()

class Embeddings:

    '''
        Build a LangChain embeddings object for a chosen provider.

        The provider is selected by ``emb`` and the concrete model by
        ``model``. Example ``(emb, model[, trust_remote])`` combinations:

            google, models/embedding-001
            openai, openai
            cohere, cohere
            hf, all-MiniLM-L6-v2
            hf, BAAI/bge-large-en-v1.5
            hf, Alibaba-NLP/gte-large-en-v1.5, True

        Attributes:
            emb: provider key -- one of 'google', 'openai', 'cohere', 'hf'.
            model: model name. Only the 'google' and 'hf' providers pass it
                on; 'openai' and 'cohere' use their library default model,
                so any placeholder value works for them.
            trust_remote: 'hf' only -- forwards trust_remote_code=True to the
                model loader when set.
            normalize: 'hf' only -- forwards normalize_embeddings=True to the
                encoder when set.
            embedding: the constructed LangChain embeddings object.
            seq_len: despite its name, the length of the vector returned by
                embed_query for a probe string, i.e. the embedding dimension.
    '''

    def __init__(self, emb, model, trust_remote=False, normalize = False):
        """Store the configuration, build the embedder and probe its size.

        Note: construction calls ``embed_query`` once (via get_emb_len), so
        it may hit a remote API or load a local model.

        Raises:
            ValueError: if ``emb`` is not a supported provider key.
        """
        self.emb=emb
        self.model = model
        self.trust_remote = trust_remote
        self.normalize = normalize
        self.embedding = self.get_embedding()
        self.seq_len = self.get_emb_len()

    def get_emb_len(self):
        """Return the length of the vector produced for a fixed probe query."""
        return len(self.embedding.embed_query('hi how are you'))

    def google_embedding(self):
        """Configure google.generativeai with GOOGLE_API_KEY and embed with self.model."""
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        embeddings = GoogleGenerativeAIEmbeddings(model = self.model)
        return embeddings

    def openai_embedding(self):
        """Build OpenAIEmbeddings using OPENAI_API_KEY (self.model is not passed)."""
        embeddings_model = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings_model

    def cohere_embedding(self):
        """Build CohereEmbeddings using COHERE_API_KEY (self.model is not passed)."""
        embeddings_model = CohereEmbeddings(cohere_api_key=os.getenv("COHERE_API_KEY"))
        return embeddings_model

    def hf_embedding(self):
        """Build HuggingFaceEmbeddings for self.model with optional loader/encoder flags."""
        # Only pass the flags when enabled so library defaults apply otherwise.
        model_args = {'trust_remote_code': True} if self.trust_remote else {}
        encode_args = {'normalize_embeddings': True} if self.normalize else {}
        embedding = HuggingFaceEmbeddings(model_name=self.model, model_kwargs = model_args, encode_kwargs = encode_args)
        return embedding

    def get_embedding(self):
        """Dispatch on self.emb to the matching provider builder.

        Raises:
            ValueError: if self.emb is not a supported provider key. Previously
                an unknown key silently returned None, which surfaced later as
                an opaque AttributeError in get_emb_len.
        """
        builders = {
            'google': self.google_embedding,
            'openai': self.openai_embedding,
            'cohere': self.cohere_embedding,
            'hf': self.hf_embedding,
        }
        builder = builders.get(self.emb)
        if builder is None:
            raise ValueError(
                f"Unknown embedding provider {self.emb!r}; "
                f"expected one of {', '.join(repr(k) for k in builders)}"
            )
        return builder()