File size: 960 Bytes
36c7297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Load model directly
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
from transformers import Trainer
import torch
import torch.nn.functional as F

class ModelWrapper:
    """Bundle a question-answering model with a sentence-embedding model.

    Attributes are populated in ``__init__``:
        model_location: the location the tokenizer and QA model were loaded from.
        tokenizer: tokenizer loaded via ``AutoTokenizer.from_pretrained``.
        model_qa: model loaded via
            ``AutoModelForQuestionAnswering.from_pretrained``.
        embedding_model: ``SentenceTransformer`` used by ``get_embeddings``.
    """

    def __init__(self, location: str = "./models/deepset/tinyroberta-squad"):
        """Load the tokenizer, the QA model and the embedding model.

        Args:
            location: Value handed to ``from_pretrained`` for both the
                tokenizer and the question-answering model.
        """
        self.model_location = location
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_location)
        self.model_qa = AutoModelForQuestionAnswering.from_pretrained(self.model_location)
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def get_embeddings(self, text, isDocument: bool):
        """Embed ``text`` and return the result reshaped to a single row.

        Args:
            text: The text to embed. When ``isDocument`` is true it must be
                a string, because it is split on ``"."``.
            isDocument: If true, ``text`` is split on ``"."``, each non-blank
                piece is encoded separately, and the per-piece embeddings
                are summed element-wise. If false, ``text`` is encoded as
                given.

        Returns:
            The embedding reshaped with ``reshape(1, -1)``.
        """
        if not isDocument:
            return self.embedding_model.encode(text).reshape(1, -1)

        # ``str.split(".")`` yields empty or whitespace-only pieces after a
        # trailing period and between consecutive periods. Encoding those
        # would add the embedding of an empty string to the document sum,
        # so they are dropped. Kept pieces are passed through unmodified.
        sentences = [piece for piece in text.split(".") if piece.strip()]
        if not sentences:
            # Nothing but separators/whitespace: encode the raw text once so
            # the caller still receives a single-row embedding.
            sentences = [text]

        sentence_embeddings = self.embedding_model.encode(sentences)
        # Element-wise sum over the per-sentence rows.
        return sentence_embeddings.sum(axis=0).reshape(1, -1)