Spaces:
Sleeping
Sleeping
import os | |
import getpass | |
import pinecone | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
from tqdm import tqdm | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.vectorstores import Pinecone | |
from langchain.document_loaders import TextLoader, DirectoryLoader | |
# os.environ["PINECONE_API_KEY"] = getpass.getpass("Pinecone API Key:") | |
# os.environ["PINECONE_ENV"] = getpass.getpass("Pinecone Environment:") | |
#gcp-starter | |
# Load the embedding checkpoint once, at import time, so that every later
# call reuses the same model and tokenizer objects.
print("Downloading model\n")
model = AutoModel.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2"
)
tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2"
)
print("\nModels downloaded")

# Pinecone credentials are read from the environment:
#   PINECONE_API_KEY - find at app.pinecone.io
#   PINECONE_ENV     - shown next to the API key in the console
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY"),
    environment=os.environ.get("PINECONE_ENV"),
)
def get_bert_embeddings(sentence):
    """Embed ``sentence`` with the module-level ``model`` and ``tokenizer``.

    The text is tokenized, passed through the model with gradient tracking
    disabled, and the hidden state at sequence position 0 is returned.

    Args:
        sentence: The text to embed.

    Returns:
        A nested list of floats: one inner list per encoded input (a single
        one here), holding the position-0 hidden-state vector.
    """
    # Truncate to the tokenizer's configured maximum length so that an
    # over-long input is clipped instead of failing inside the forward pass.
    # Inputs that already fit are encoded exactly as before.
    input_ids = tokenizer.encode(sentence, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model(input_ids)
    # NOTE(review): only sequence position 0 is kept (the [CLS] token for
    # BERT-style tokenizers). Presumably the vectors stored in the index were
    # produced the same way -- confirm before changing the pooling strategy.
    return output.last_hidden_state[:, 0, :].numpy().tolist()
def fetch_top_k(input_data, history, top_k=3):
    """Return the stored text of the passages most similar to ``input_data``.

    The query text is embedded with ``get_bert_embeddings``, the
    "ophtal-knowledge-base" Pinecone index is queried for the nearest
    vectors, and the ``text`` metadata of each match is collected.

    Args:
        input_data: The query text.
        history: Not used by this function; kept so existing callers that
            pass two positional arguments keep working.
        top_k: Maximum number of matches to retrieve. Defaults to 3, the
            value that was previously hard-coded.

    Returns:
        The matched passages' ``text`` metadata joined with newlines, in the
        order the matches were returned. An empty string when the query
        yields no matches.
    """
    index_name = "ophtal-knowledge-base"
    index = pinecone.Index(index_name)
    emb = get_bert_embeddings(input_data)
    query = index.query(
        vector=emb,
        top_k=top_k,
        include_values=True,
    )
    id_list = [match['id'] for match in query['matches']]

    text_list = []
    if id_list:
        # Fetch all matched records in a single round-trip and reuse the
        # response, instead of re-fetching the same ids for every field.
        vectors = index.fetch(id_list)['vectors']
        text_list = [vectors[id_]['metadata']['text'] for id_ in id_list]

    print(text_list)
    return '\n'.join(text_list)
# Example usage (kept commented out so that importing this module does not
# issue a query):
# if __name__ == "__main__":
#     neet = "Which of the following is true regarding Mittendorf dot?\nA. Glial tissue projecting from optic disc\nB. Obliterated vessel running forward into the vitreous\nC. Associated with posterior polar cataract\nD. Commonest congenital anomaly of hyaloid system"
#     context = fetch_top_k(neet, history=[])
#     print(context)