import pandas as pd
import os
from tqdm.auto import tqdm
import pinecone
from sentence_transformers import SentenceTransformer
import torch

class PineconeIndex:
    """Build and query a Pinecone index of disease Q&A passages embedded
    with a SentenceTransformer model."""

    def __init__(self):
        # embed on GPU when available, otherwise fall back to CPU
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
        self.index_name = 'semantic-search-fast-med'
        self.index = None

    def init_pinecone(self):

        index_name = self.index_name
        sentence_model = self.sm

        # get api key from app.pinecone.io
        PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY', 'None')
        
        # find your environment next to the api key in pinecone console
        PINECONE_ENV = os.environ.get('PINECONE_ENV', "us-west4-gcp")

        pinecone.init(
            api_key=PINECONE_API_KEY,
            environment=PINECONE_ENV
        )

        # pinecone.delete_index(index_name)  # uncomment to drop and rebuild the index

        # only create index if it doesn't exist
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                dimension=sentence_model.get_sentence_embedding_dimension(),
                metric='cosine'
            )

        # now connect to the index
        self.index = pinecone.GRPCIndex(index_name)
        return self.index

    def build_index(self):

        if self.index is None:
            index = self.init_pinecone()
        else:
            index = self.index

        # nothing to do if the index already holds vectors
        if index.describe_index_stats()['total_vector_count']:
            print("Index already built")
            return

        sentence_model = self.sm

        # one row per disease: column 0 is the name, later columns hold the section text
        x = pd.read_excel('Diseases_data_W.xlsx')

        question_dict = {
            'About': 'What is {}?',
            'Symptoms': 'What are symptoms of {}?',
            'Causes': 'What are causes of {}?',
            'Diagnosis': 'What is the diagnosis for {}?',
            'Risk Factors': 'What are the risk factors for {}?',
            'Treatment Options': 'What are the treatment options for {}?',
            'Prognosis and Complications': 'What are the prognosis and complications of {}?'
        }
        context = []
        disease_list = []

        # turn every populated cell into a "question\n\nanswer" passage,
        # skipping duplicate disease rows
        for i in range(len(x)):
            disease = x.iloc[i, 0]
            if disease.strip().lower() in disease_list:
                continue

            disease_list.append(disease.strip().lower())

            # section names present for this disease, and their text
            conditions = x.iloc[i, 1:].dropna().index
            answers = x.iloc[i, 1:].dropna()

            for cond in conditions:
                context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")

        batch_size = 128
        for i in tqdm(range(0, len(context), batch_size)):
            # find end of batch
            i_end = min(i + batch_size, len(context))
            # create IDs batch
            ids = [str(n) for n in range(i, i_end)]
            # create metadata batch
            metadatas = [{'text': text} for text in context[i:i_end]]
            # create embeddings
            xc = sentence_model.encode(context[i:i_end])
            # create records list for upsert
            records = zip(ids, xc, metadatas)
            # upsert to Pinecone
            index.upsert(vectors=records)

        # report the number of records now in the index
        print(index.describe_index_stats())

    def search(self, query: str = "medicines for fever"):

        sentence_model = self.sm

        # lazily initialise and populate the index on first use
        if self.index is None:
            self.build_index()

        index = self.index

        # embed the query with the same sentence-transformer model
        xq = sentence_model.encode(query).tolist()

        # retrieve the three closest passages along with their text metadata
        xc = index.query(xq, top_k=3, include_metadata=True)
        
        return xc
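

# Minimal usage sketch (not part of the original module). It assumes the
# PINECONE_API_KEY / PINECONE_ENV environment variables are set and that
# 'Diseases_data_W.xlsx' is in the working directory; the query string below
# is purely illustrative.
if __name__ == '__main__':
    pc = PineconeIndex()
    pc.build_index()  # populates the index on first run; skipped once built
    results = pc.search("What are the symptoms of malaria?")
    for match in results['matches']:
        # each stored passage starts with its templated question line
        print(round(match['score'], 3), '-', match['metadata']['text'].split('\n')[0])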