# ScholarBot: src/embedding.py
# Commit 9c37331 ("initial commit") by vinny4
from typing import Union, List
from langchain.embeddings import HuggingFaceEmbeddings
class EmbeddingModel:
    """
    Wrapper around a text-embedding backend.

    Only the ``"huggingface"`` backend (LangChain's ``HuggingFaceEmbeddings``)
    is currently implemented; any other ``model_type`` is rejected with a
    ``ValueError`` when the model is loaded in ``__init__``.
    """

    def __init__(self, model_type: str = "huggingface", model_name: str = "all-MiniLM-L6-v2"):
        """
        Store the configuration and eagerly load the backend model.

        :param model_type: Backend identifier. Only ``"huggingface"`` is supported.
        :param model_name: Model name forwarded to the backend constructor.
        :raises ValueError: If ``model_type`` is not supported.
        """
        self.model_type = model_type
        self.model_name = model_name
        self.model = self._load_model()

    def _load_model(self):
        """
        Instantiate the backend model selected by ``self.model_type``.

        :return: A ``HuggingFaceEmbeddings`` instance for ``"huggingface"``.
        :raises ValueError: If ``self.model_type`` is not supported.
        """
        if self.model_type == "huggingface":
            return HuggingFaceEmbeddings(model_name=self.model_name)
        # Implementation for other model types can be added here (and a
        # matching branch in ``embed``).
        raise ValueError(f"Unsupported model type: {self.model_type}")

    def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """
        Generate embeddings for the given text.

        :param text: A single string, or a list/tuple of strings.
        :return: For a string, one embedding (a list of floats). For a
            list/tuple, a list of embeddings in the same order as the input
            (an empty list for an empty input).
        :raises NotImplementedError: If ``self.model_type`` has no embedding
            implementation.
        """
        if self.model_type != "huggingface":
            # ``_load_model`` can only build the "huggingface" backend, so that
            # is the only type this method needs to dispatch on.
            raise NotImplementedError(f"Embedding for {self.model_type} is not implemented.")

        if isinstance(text, (list, tuple)):
            # Skip the backend entirely for an empty batch.
            if not text:
                return []
            # Batch API: one backend call for all strings instead of one
            # ``embed_query`` call per string.
            return self.model.embed_documents(list(text))

        return self.model.embed_query(text)