# drvai-rag / rag.py
# Uploaded by aakash0017 using huggingface_hub (commit b5f8985).
import os
import getpass
import pinecone
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Pinecone
from langchain.document_loaders import TextLoader, DirectoryLoader
# os.environ["PINECONE_API_KEY"] = getpass.getpass("Pinecone API Key:")
# os.environ["PINECONE_ENV"] = getpass.getpass("Pinecone Environment:")
#gcp-starter
# Hugging Face model id used for both the encoder weights and its tokenizer.
_EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Load the encoder and tokenizer once at import time; get_bert_embeddings
# reads these module-level names on every call.
print("Downloading model")
print()
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_ID)
print()
print("Models downloaded")

# Pinecone credentials are taken from the environment.
_pinecone_api_key = os.getenv("PINECONE_API_KEY")  # find at app.pinecone.io
_pinecone_env = os.getenv("PINECONE_ENV")  # next to api key in console
pinecone.init(
    api_key=_pinecone_api_key,
    environment=_pinecone_env,
)
def get_bert_embeddings(sentence):
    """Embed *sentence* with the module-level encoder, using the position-0 vector.

    The text is tokenized with the module-level ``tokenizer`` and run through
    the module-level ``model`` with gradient tracking disabled. The last-layer
    hidden state at token position 0 is returned.

    Args:
        sentence: Text to embed.

    Returns:
        A nested Python list with one row for the single encoded sequence,
        i.e. shape ``[1, hidden_size]``. Callers that need a flat vector
        should take element ``[0]``.
    """
    # Truncate to the tokenizer's maximum length. Without it, inputs longer
    # than the model's position-embedding table raise at inference time.
    input_ids = tokenizer.encode(sentence, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(input_ids)
    # Select token position 0 for every sequence in the (size-1) batch.
    # NOTE(review): for BERT-style tokenizers position 0 is presumably [CLS];
    # sentence-transformers checkpoints are normally mean-pooled instead.
    # Pooling is kept as-is so vectors stay comparable with what was already
    # upserted to the index — confirm how the index was built before changing.
    return output.last_hidden_state[:, 0, :].numpy().tolist()
def fetch_top_k(input_data, history=None, *, top_k=3, index_name="ophtal-knowledge-base"):
    """Return the stored text of the ``top_k`` nearest chunks to *input_data*.

    Args:
        input_data: Query text; embedded with ``get_bert_embeddings``.
        history: Accepted for caller compatibility; currently unused.
        top_k: Number of nearest matches to retrieve (default 3).
        index_name: Name of the Pinecone index to query.

    Returns:
        The ``metadata["text"]`` of each match, in the order Pinecone returns
        the matches, joined with newlines.

    Raises:
        KeyError: if a matched record has no ``"text"`` metadata field.
    """
    index = pinecone.Index(index_name)
    # get_bert_embeddings returns one row per encoded sequence; the query
    # expects a single flat list of floats.
    query_vector = get_bert_embeddings(input_data)[0]
    # Ask for metadata alongside the matches so no separate fetch round trip
    # is needed; the vector values themselves are not used.
    result = index.query(
        vector=query_vector,
        top_k=top_k,
        include_metadata=True,
    )
    text_list = [match["metadata"]["text"] for match in result["matches"]]
    print(text_list)
    return "\n".join(text_list)
# if __name__ == "__main__":
#     neet = "Which of the following is true regarding Mittendorf dot?\nA. Glial tissue projecting from optic disc\nB. Obliterated vessel running forward into the vitreous\nC. Associated with posterior polar cataract\nD. Commonest congenital anomaly of hyaloid system"
#     context = fetch_top_k(neet, [])  # returns a single newline-joined string
#     print(context)