Updated rag.py
Browse files
rag.py
CHANGED
@@ -10,10 +10,9 @@ class RAG:
|
|
10 |
def __init__(self):
|
11 |
self.model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
12 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
-
local_dir = "llm_models/"
|
14 |
|
15 |
self.embedding_model_name = "all-mpnet-base-v2"
|
16 |
-
self.embeddings_filename = "
|
17 |
|
18 |
self.data_pd = pd.read_csv(self.embeddings_filename)
|
19 |
self.data_dict = pd.read_csv(self.embeddings_filename).to_dict(orient='records')
|
@@ -23,11 +22,9 @@ class RAG:
|
|
23 |
# Embedding model
|
24 |
self.embedding_model = SentenceTransformer(model_name_or_path = self.embedding_model_name,device = self.device)
|
25 |
# Tokenizer
|
26 |
-
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_id
|
27 |
-
cache_dir = local_dir)
|
28 |
# LLM
|
29 |
-
self.llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=self.model_id
|
30 |
-
cache_dir = local_dir).to(self.device)
|
31 |
|
32 |
def get_embeddings(self) -> list:
|
33 |
"""Returns the embeddings from the csv file"""
|
|
|
10 |
def __init__(self):
|
11 |
self.model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
12 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
13 |
|
14 |
self.embedding_model_name = "all-mpnet-base-v2"
|
15 |
+
self.embeddings_filename = "embeddings.csv"
|
16 |
|
17 |
self.data_pd = pd.read_csv(self.embeddings_filename)
|
18 |
self.data_dict = pd.read_csv(self.embeddings_filename).to_dict(orient='records')
|
|
|
22 |
# Embedding model
|
23 |
self.embedding_model = SentenceTransformer(model_name_or_path = self.embedding_model_name,device = self.device)
|
24 |
# Tokenizer
|
25 |
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=self.model_id)
|
|
|
26 |
# LLM
|
27 |
+
self.llm_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=self.model_id).to(self.device)
|
|
|
28 |
|
29 |
def get_embeddings(self) -> list:
|
30 |
"""Returns the embeddings from the csv file"""
|