xpsychted committed on
Commit 4e44d93
1 Parent(s): e8cbecb

Update app.py

Files changed (1):
  app.py +111 -118
app.py CHANGED
@@ -1,167 +1,160 @@
  import gradio as gr
- import pandas as pd
  import os
- from tqdm.auto import tqdm
- import pinecone
- from sentence_transformers import SentenceTransformer
- import torch
-
- from transformers import AutoModel, AutoConfig
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ from pinecone_integration import PineconeIndex
+ from qa_model import QAModel

  # !pip install transformers accelerate
  # !pip install -qU pinecone-client[grpc] sentence-transformers
  # !pip install gradio

- class PineconeIndex:
+ # class PineconeIndex:

-     def __init__(self):
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ #     def __init__(self):
+ #         device = 'cuda' if torch.cuda.is_available() else 'cpu'

-         self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
-         self.index_name = 'semantic-search-fast-med'
-         self.index = None
+ #         self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+ #         self.index_name = 'semantic-search-fast-med'
+ #         self.index = None

-     def init_pinecone(self):
+ #     def init_pinecone(self):

-         index_name = self.index_name
-         sentence_model = self.sm
+ #         index_name = self.index_name
+ #         sentence_model = self.sm

-         # get api key from app.pinecone.io
-         PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee" # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
-         # find your environment next to the api key in pinecone console
-         PINECONE_ENV = "us-west4-gcp" # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'
+ #         # get api key from app.pinecone.io
+ #         PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee" # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
+ #         # find your environment next to the api key in pinecone console
+ #         PINECONE_ENV = "us-west4-gcp" # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'

-         pinecone.init(
-             api_key=PINECONE_API_KEY,
-             environment=PINECONE_ENV
-         )
+ #         pinecone.init(
+ #             api_key=PINECONE_API_KEY,
+ #             environment=PINECONE_ENV
+ #         )

-         # pinecone.delete_index(index_name)
+ #         # pinecone.delete_index(index_name)

-         # only create index if it doesn't exist
-         if index_name not in pinecone.list_indexes():
-             pinecone.create_index(
-                 name=index_name,
-                 dimension=sentence_model.get_sentence_embedding_dimension(),
-                 metric='cosine'
-             )
+ #         # only create index if it doesn't exist
+ #         if index_name not in pinecone.list_indexes():
+ #             pinecone.create_index(
+ #                 name=index_name,
+ #                 dimension=sentence_model.get_sentence_embedding_dimension(),
+ #                 metric='cosine'
+ #             )

-         # now connect to the index
-         self.index = pinecone.GRPCIndex(index_name)
-         return self.index
+ #         # now connect to the index
+ #         self.index = pinecone.GRPCIndex(index_name)
+ #         return self.index

-     def build_index(self):
+ #     def build_index(self):

-         if self.index is None:
-             index = self.init_pinecone()
-         else:
-             index = self.index
+ #         if self.index is None:
+ #             index = self.init_pinecone()
+ #         else:
+ #             index = self.index

-         if index.describe_index_stats()['total_vector_count']:
-             "Index already built"
-             return
+ #         if index.describe_index_stats()['total_vector_count']:
+ #             "Index already built"
+ #             return

-         sentence_model = self.sm
+ #         sentence_model = self.sm

-         x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')
+ #         x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')

-         question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
-                          'Causes': 'What are causes of {}?',
-                          'Diagnosis': 'What are diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
-                          'Treatment Options': 'What are the treatment options for {}?',
-                          'Prognosis and Complications': 'What are the prognosis and complications?'}
-         context = []
-         disease_list = []
+ #         question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
+ #                          'Causes': 'What are causes of {}?',
+ #                          'Diagnosis': 'What are diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
+ #                          'Treatment Options': 'What are the treatment options for {}?',
+ #                          'Prognosis and Complications': 'What are the prognosis and complications?'}
+ #         context = []
+ #         disease_list = []

-         for i in range(len(x)):
-             disease = x.iloc[i, 0]
-             if disease.strip().lower() in disease_list:
-                 continue
+ #         for i in range(len(x)):
+ #             disease = x.iloc[i, 0]
+ #             if disease.strip().lower() in disease_list:
+ #                 continue

-             disease_list.append(disease.strip().lower())
+ #             disease_list.append(disease.strip().lower())

-             conditions = x.iloc[i, 1:].dropna().index
-             answers = x.iloc[i, 1:].dropna()
+ #             conditions = x.iloc[i, 1:].dropna().index
+ #             answers = x.iloc[i, 1:].dropna()

-             for cond in conditions:
-                 context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
+ #             for cond in conditions:
+ #                 context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")

-         batch_size = 128
-         for i in tqdm(range(0, len(context), batch_size)):
-             # find end of batch
-             i_end = min(i + batch_size, len(context))
-             # create IDs batch
-             ids = [str(x) for x in range(i, i_end)]
-             # create metadata batch
-             metadatas = [{'text': text} for text in context[i:i_end]]
-             # create embeddings
-             xc = sentence_model.encode(context[i:i_end])
-             # create records list for upsert
-             records = zip(ids, xc, metadatas)
-             # upsert to Pinecone
-             index.upsert(vectors=records)
+ #         batch_size = 128
+ #         for i in tqdm(range(0, len(context), batch_size)):
+ #             # find end of batch
+ #             i_end = min(i + batch_size, len(context))
+ #             # create IDs batch
+ #             ids = [str(x) for x in range(i, i_end)]
+ #             # create metadata batch
+ #             metadatas = [{'text': text} for text in context[i:i_end]]
+ #             # create embeddings
+ #             xc = sentence_model.encode(context[i:i_end])
+ #             # create records list for upsert
+ #             records = zip(ids, xc, metadatas)
+ #             # upsert to Pinecone
+ #             index.upsert(vectors=records)

-         # check number of records in the index
-         index.describe_index_stats()
+ #         # check number of records in the index
+ #         index.describe_index_stats()

-     def search(self, query: str = "medicines for fever"):
+ #     def search(self, query: str = "medicines for fever"):

-         sentence_model = self.sm
+ #         sentence_model = self.sm

-         if self.index is None:
-             self.build_index()
+ #         if self.index is None:
+ #             self.build_index()

-         index = self.index
+ #         index = self.index

-         # create the query vector
-         xq = sentence_model.encode(query).tolist()
+ #         # create the query vector
+ #         xq = sentence_model.encode(query).tolist()

-         # now query
-         xc = index.query(xq, top_k = 3, include_metadata = True)
+ #         # now query
+ #         xc = index.query(xq, top_k = 3, include_metadata = True)

-         return xc
+ #         return xc


- class QAModel():
-     def __init__(self, checkpoint="google/flan-t5-xl"):
-         self.checkpoint = checkpoint
-         self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
+ # class QAModel():
+ #     def __init__(self, checkpoint="google/flan-t5-xl"):
+ #         self.checkpoint = checkpoint
+ #         self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"

-     def store_sharded_model(self):
-         tmpdir = self.tmpdir
+ #     def store_sharded_model(self):
+ #         tmpdir = self.tmpdir

-         checkpoint = self.checkpoint
+ #         checkpoint = self.checkpoint

-         if not os.path.exists(tmpdir):
-             os.mkdir(tmpdir)
-             print(f"Directory created - {tmpdir}")
-         model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
-         print(f"Model loaded - {checkpoint}")
-         model.save_pretrained(tmpdir, max_shard_size="200MB")
-
+ #         if not os.path.exists(tmpdir):
+ #             os.mkdir(tmpdir)
+ #             print(f"Directory created - {tmpdir}")
+ #         model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+ #         print(f"Model loaded - {checkpoint}")
+ #         model.save_pretrained(tmpdir, max_shard_size="200MB")
+
-     def load_sharded_model(self):
-         tmpdir = self.tmpdir
-         if not os.path.exists(tmpdir):
-             self.store_sharded_model()
+ #     def load_sharded_model(self):
+ #         tmpdir = self.tmpdir
+ #         if not os.path.exists(tmpdir):
+ #             self.store_sharded_model()

-         checkpoint = self.checkpoint
+ #         checkpoint = self.checkpoint


-         config = AutoConfig.from_pretrained(checkpoint)
+ #         config = AutoConfig.from_pretrained(checkpoint)

-         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-         with init_empty_weights():
-             model = AutoModelForSeq2SeqLM.from_config(config)
-         # model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)
+ #         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ #         with init_empty_weights():
+ #             model = AutoModelForSeq2SeqLM.from_config(config)
+ #         # model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)

-         model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
-         return model, tokenizer
+ #         model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
+ #         return model, tokenizer

-     def query_model(self, model, tokenizer, query):
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         return tokenizer.batch_decode(model.generate(**tokenizer(query, return_tensors='pt').to(device)), skip_special_tokens=True)[0]
+ #     def query_model(self, model, tokenizer, query):
+ #         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ #         return tokenizer.batch_decode(model.generate(**tokenizer(query, return_tensors='pt').to(device)), skip_special_tokens=True)[0]

  PI = PineconeIndex()
  PI.build_index()
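
For orientation: this commit keeps app.py's behavior but sources PineconeIndex and QAModel from the new pinecone_integration and qa_model modules instead of defining them inline. Below is a minimal sketch of how the two pieces could be wired into the Gradio app after the refactor. The interface code itself sits outside this hunk, so the answer function and the dict-style access into the Pinecone query response are illustrative assumptions, not lines from this commit.

# Hypothetical wiring sketch -- assumes pinecone_integration.py and qa_model.py
# now hold the classes this commit comments out of app.py.
import gradio as gr
from pinecone_integration import PineconeIndex
from qa_model import QAModel

PI = PineconeIndex()
PI.build_index()            # connect to (or populate) the Pinecone index

qa = QAModel()              # defaults to the google/flan-t5-xl checkpoint
model, tokenizer = qa.load_sharded_model()

def answer(query):
    # Retrieve the top-3 semantically similar passages, then condition the
    # seq2seq model on the best match (dict-style response access assumed).
    matches = PI.search(query)
    best = matches['matches'][0]['metadata']['text']
    return qa.query_model(model, tokenizer, f"{best}\n\n{query}")

gr.Interface(fn=answer, inputs="text", outputs="text").launch()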