File size: 1,122 Bytes
9e88bc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import List
import os
from langchain_core.embeddings import Embeddings
from transformers import AutoModel, AutoTokenizer
# Set TOKENIZERS_PARALLELISM to "false" for this process at import time.
# NOTE(review): presumably this turns off the Hugging Face tokenizers'
# internal thread pool (and the warning it prints after a fork) -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Hugging Face Hub id of the checkpoint used for feature extraction.
_KOSIMCSE_MODEL_NAME = "BM-K/KoSimCSE-roberta"

# Module-level cache so the checkpoint is loaded once per process instead of
# on every call to get_roberta_embeddings().
_KOSIMCSE_CACHE = {}


def _load_kosimcse_roberta():
    """Return the cached ``(tokenizer, model)`` pair, loading it on first use."""
    if "pair" not in _KOSIMCSE_CACHE:
        model = AutoModel.from_pretrained(_KOSIMCSE_MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(_KOSIMCSE_MODEL_NAME)
        _KOSIMCSE_CACHE["pair"] = (tokenizer, model)
    return _KOSIMCSE_CACHE["pair"]


def get_roberta_embeddings(sentences: List[str]) -> List[List[float]]:
    """Get features of Korean input texts w/ BM-K/KoSimCSE-roberta.

    Each sentence is represented by the hidden state of its first token,
    taken from the first output of the model.

    Args:
        sentences: Texts to embed. They are tokenized together as one padded,
            truncated batch.

    Returns:
        One vector (list of floats, dimension 768) per input sentence, in
        input order. An empty input yields an empty list without loading
        the model.
    """
    # Nothing to embed: skip tokenization and the (expensive) model load.
    if not sentences:
        return []

    # Imported here so the empty-input path does not need torch at all.
    import torch

    tokenizer, model = _load_kosimcse_roberta()
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Inference only: no autograd graph is needed, which also makes an
    # explicit .detach() on the outputs unnecessary.
    with torch.no_grad():
        last_hidden_state = model(**inputs, return_dict=False)[0]

    # Index 0 along the sequence axis selects the first token of every
    # sentence in the batch.
    return last_hidden_state[:, 0].tolist()
class KorRobertaEmbeddings(Embeddings):
    """Embeddings implementation backed by BM-K/KoSimCSE-roberta.

    Both methods delegate the actual feature extraction to
    ``get_roberta_embeddings``.
    """

    # Embedding size advertised for this model.
    dimension = 768

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts and return one vector per text."""
        vectors = get_roberta_embeddings(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        # Wrap the query in a one-element batch, then unwrap the result.
        single_batch = [text]
        vectors = get_roberta_embeddings(single_batch)
        return vectors[0]
|