import os
from typing import List

import torch
from langchain_core.embeddings import Embeddings
from transformers import AutoModel, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def get_roberta_embeddings(sentences: List[str]) -> List[List[float]]:
    """
    Extract sentence embeddings for Korean input texts with BM-K/KoSimCSE-roberta.

    Returns:
        List[List[float]]: one 768-dimensional vector per input sentence.
    """
    # Note: the model and tokenizer are reloaded on every call; cache them
    # at module level if this function is called repeatedly.
    model = AutoModel.from_pretrained("BM-K/KoSimCSE-roberta")
    tokenizer = AutoTokenizer.from_pretrained("BM-K/KoSimCSE-roberta")
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    # With return_dict=False the model returns (last_hidden_state, pooler_output).
    with torch.no_grad():
        embeddings, _ = model(**inputs, return_dict=False)
    vectors = []
    for embedding in embeddings:
        # Use the [CLS] token representation as the sentence embedding.
        vector = embedding[0].numpy().tolist()
        vectors.append(vector)
    return vectors


class KorRobertaEmbeddings(Embeddings):
    """Feature extraction with BM-K/KoSimCSE-roberta."""

    dimension: int = 768

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return get_roberta_embeddings(texts)

    def embed_query(self, text: str) -> List[float]:
        return get_roberta_embeddings([text])[0]
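

if __name__ == "__main__":
    # Minimal usage sketch, assuming the BM-K/KoSimCSE-roberta weights can
    # be downloaded from the Hugging Face Hub; the sample sentences are
    # illustrative only.
    embedder = KorRobertaEmbeddings()
    docs = [
        "치타가 들판을 가로 질러 먹이를 쫓는다.",
        "치타 한 마리가 먹이 뒤에서 달리고 있다.",
    ]
    doc_vectors = embedder.embed_documents(docs)
    query_vector = embedder.embed_query(docs[0])
    assert len(doc_vectors) == len(docs)
    assert len(query_vector) == KorRobertaEmbeddings.dimension  # 768

    # The wrapper also plugs into LangChain vector stores, e.g. (assuming
    # langchain-community and faiss-cpu are installed):
    #     from langchain_community.vectorstores import FAISS
    #     store = FAISS.from_texts(docs, embedding=embedder)
    #     store.similarity_search(docs[0], k=1)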