import os

import h5py
import numpy as np
from loguru import logger
from sentence_transformers import SentenceTransformer


class EmbeddingsManager:
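    """Load or generate sentence embeddings for a text corpus, caching them in an HDF5 file."""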
    def __init__(self, model_name, bible_version, texts, embeddings_cache_dir) -> None:

        # Load embeddings model
        self.model = SentenceTransformer(model_name)

        # Load or generate embeddings based on the corpus
        sanitized_model_name = model_name.replace("\\", "-").replace("/", "-")
        self.cache_filename = f"{bible_version}_{sanitized_model_name}.h5"
        self.emb_cache_filepath = os.path.join(
            embeddings_cache_dir, self.cache_filename
        )

        # Load cached embeddings if they exist
        try:
            with h5py.File(self.emb_cache_filepath, "r") as h:
                self.embeddings = np.array(h["embeddings"])
        except (OSError, KeyError):
            # Cache miss or unreadable cache: generate embeddings and write them to disk
            logger.info(
                f"No embeddings cache found; generating embeddings and saving to {self.emb_cache_filepath}"
            )
            self.embeddings = self.model.encode(texts)
            os.makedirs(embeddings_cache_dir, exist_ok=True)
            with h5py.File(self.emb_cache_filepath, "w") as f:
                f.create_dataset("embeddings", data=self.embeddings)

        # Create a look-up dict to quickly retrieve embeddings of texts
        self.text_emb_dict = {}
        for text, embedding in zip(texts, self.embeddings):
            self.text_emb_dict[text] = embedding

        logger.info(
            f"Successfully prepared {model_name} embeddings for {bible_version} (cached at {self.emb_cache_filepath})."
        )

    def get_embeddings(self, texts):
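        """Return an embedding for each text, encoding and caching any texts not already in the lookup dict."""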
        embeddings = []
        for text in texts:
            if text not in self.text_emb_dict:
                self.text_emb_dict[text] = self.model.encode([text])[0]
            embeddings.append(self.text_emb_dict[text])
        return embeddings

    def __str__(self):
        return self.emb_cache_filepath


def score_semantic_similarity(query, texts_df):
    """Returns copy of text_df with semantic similarity scores."""
    pass
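

# A minimal, hypothetical sketch of how score_semantic_similarity might be filled in.
# Assumptions (not part of the original stub above): texts_df is a pandas DataFrame
# with a "text" column, an EmbeddingsManager instance is passed in explicitly, and
# cosine similarity is the intended score.
def example_score_semantic_similarity(query, texts_df, embeddings_manager):
    """Illustrative sketch: copy of texts_df with a `semantic_similarity` column scored against query."""
    scored_df = texts_df.copy()

    # Embed the query and all candidate texts via the (cached) embeddings manager.
    query_emb = np.asarray(embeddings_manager.get_embeddings([query])[0], dtype=float)
    text_embs = np.asarray(
        embeddings_manager.get_embeddings(scored_df["text"].tolist()), dtype=float
    )

    # Cosine similarity between the query embedding and every text embedding.
    scores = text_embs @ query_emb / (
        np.linalg.norm(text_embs, axis=1) * np.linalg.norm(query_emb)
    )
    scored_df["semantic_similarity"] = scores
    return scored_df


# Hypothetical usage: example_score_semantic_similarity("love your neighbor", verses_df, manager)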