xpsychted committed
Commit
063a814
1 Parent(s): e6c98a6

Update app.py

Files changed (1)
  1. app.py +1 -153
app.py CHANGED
@@ -3,158 +3,6 @@ import os
  from pinecone_integration import PineconeIndex
  from qa_model import QAModel
 
- # !pip install transformers accelerate
- # !pip install -qU pinecone-client[grpc] sentence-transformers
- # !pip install gradio
-
- # class PineconeIndex:
-
- #     def __init__(self):
- #         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- #         self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
- #         self.index_name = 'semantic-search-fast-med'
- #         self.index = None
-
- #     def init_pinecone(self):
-
- #         index_name = self.index_name
- #         sentence_model = self.sm
-
- #         # get api key from app.pinecone.io
- #         PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee" # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
- #         # find your environment next to the api key in pinecone console
- #         PINECONE_ENV = "us-west4-gcp" # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'
-
- #         pinecone.init(
- #             api_key=PINECONE_API_KEY,
- #             environment=PINECONE_ENV
- #         )
-
- #         # pinecone.delete_index(index_name)
-
- #         # only create index if it doesn't exist
- #         if index_name not in pinecone.list_indexes():
- #             pinecone.create_index(
- #                 name=index_name,
- #                 dimension=sentence_model.get_sentence_embedding_dimension(),
- #                 metric='cosine'
- #             )
-
- #         # now connect to the index
- #         self.index = pinecone.GRPCIndex(index_name)
- #         return self.index
-
- #     def build_index(self):
-
- #         if self.index is None:
- #             index = self.init_pinecone()
- #         else:
- #             index = self.index
-
- #         if index.describe_index_stats()['total_vector_count']:
- #             "Index already built"
- #             return
-
- #         sentence_model = self.sm
-
- #         x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')
-
- #         question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
- #                          'Causes': 'What are causes of {}?',
- #                          'Diagnosis': 'What are diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
- #                          'Treatment Options': 'What are the treatment options for {}?',
- #                          'Prognosis and Complications': 'What are the prognosis and complications?'}
- #         context = []
- #         disease_list = []
-
- #         for i in range(len(x)):
- #             disease = x.iloc[i, 0]
- #             if disease.strip().lower() in disease_list:
- #                 continue
-
- #             disease_list.append(disease.strip().lower())
-
- #             conditions = x.iloc[i, 1:].dropna().index
- #             answers = x.iloc[i, 1:].dropna()
-
- #             for cond in conditions:
- #                 context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
-
- #         batch_size = 128
- #         for i in tqdm(range(0, len(context), batch_size)):
- #             # find end of batch
- #             i_end = min(i + batch_size, len(context))
- #             # create IDs batch
- #             ids = [str(x) for x in range(i, i_end)]
- #             # create metadata batch
- #             metadatas = [{'text': text} for text in context[i:i_end]]
- #             # create embeddings
- #             xc = sentence_model.encode(context[i:i_end])
- #             # create records list for upsert
- #             records = zip(ids, xc, metadatas)
- #             # upsert to Pinecone
- #             index.upsert(vectors=records)
-
- #         # check number of records in the index
- #         index.describe_index_stats()
-
- #     def search(self, query: str = "medicines for fever"):
-
- #         sentence_model = self.sm
-
- #         if self.index is None:
- #             self.build_index()
-
- #         index = self.index
-
- #         # create the query vector
- #         xq = sentence_model.encode(query).tolist()
-
- #         # now query
- #         xc = index.query(xq, top_k = 3, include_metadata = True)
-
- #         return xc
-
-
- # class QAModel():
- #     def __init__(self, checkpoint="google/flan-t5-xl"):
- #         self.checkpoint = checkpoint
- #         self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
-
- #     def store_sharded_model(self):
- #         tmpdir = self.tmpdir
-
- #         checkpoint = self.checkpoint
-
- #         if not os.path.exists(tmpdir):
- #             os.mkdir(tmpdir)
- #             print(f"Directory created - {tmpdir}")
- #             model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
- #             print(f"Model loaded - {checkpoint}")
- #             model.save_pretrained(tmpdir, max_shard_size="200MB")
-
- #     def load_sharded_model(self):
- #         tmpdir = self.tmpdir
- #         if not os.path.exists(tmpdir):
- #             self.store_sharded_model()
-
- #         checkpoint = self.checkpoint
-
-
- #         config = AutoConfig.from_pretrained(checkpoint)
-
- #         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- #         with init_empty_weights():
- #             model = AutoModelForSeq2SeqLM.from_config(config)
- #         # model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)
-
- #         model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
- #         return model, tokenizer
-
- #     def query_model(self, model, tokenizer, query):
- #         device = 'cuda' if torch.cuda.is_available() else 'cpu'
- #         return tokenizer.batch_decode(model.generate(**tokenizer(query, return_tensors='pt').to(device)), skip_special_tokens=True)[0]
 
  PI = PineconeIndex()
  PI.build_index()
@@ -192,7 +40,7 @@ demo = gr.Interface(
      ],
      cache_examples=True,
      title="MedQA assistant",
-     #description="MedQA assistant"
+     description='Check out the repository at: https://github.com/anandshah98/MedQA',
  )
 
  demo.launch()
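
For orientation, below is a minimal sketch of what app.py plausibly reads like after this commit, reconstructed only from the lines visible in the diff: the PineconeIndex and QAModel implementations now live in pinecone_integration.py and qa_model.py, and app.py just builds the index and serves the Gradio interface. The answer function, the Gradio input/output components, and the example question are assumptions; the diff leaves those lines (new lines 7-39) unchanged and does not show them.

# A minimal sketch, not the actual file: how app.py might look after this commit,
# given only the lines visible in the diff. Everything marked "assumed" below
# lives in the unchanged, unshown middle of the file (new lines 7-39).
import gradio as gr

from pinecone_integration import PineconeIndex
from qa_model import QAModel

PI = PineconeIndex()
PI.build_index()

qa = QAModel()                              # assumed; removed code defaults to google/flan-t5-xl
model, tokenizer = qa.load_sharded_model()  # assumed wiring


def answer(question: str) -> str:
    # Assumed retrieval-augmented flow: fetch the top Pinecone matches for the
    # question, then let the seq2seq model answer against that retrieved context.
    results = PI.search(question)           # Pinecone query response; shape assumed
    context = "\n\n".join(m["metadata"]["text"] for m in results["matches"])
    prompt = f"Answer the question using the context.\n\nContext:\n{context}\n\nQuestion: {question}"
    return qa.query_model(model, tokenizer, prompt)


demo = gr.Interface(
    fn=answer,                               # assumed function name
    inputs=gr.Textbox(label="Question"),     # assumed components
    outputs=gr.Textbox(label="Answer"),
    examples=[["What are symptoms of fever?"]],  # placeholder example, not from the diff
    cache_examples=True,
    title="MedQA assistant",
    description='Check out the repository at: https://github.com/anandshah98/MedQA',
)

demo.launch()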
 
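One detail worth noting from the removed block: init_pinecone() hard-codes the Pinecone API key and environment, with the environment-variable alternative left in a trailing comment. Here is a short sketch of that alternative, using only the calls already present in the removed code (os.environ.get and pinecone.init); it is a suggestion, not a description of what pinecone_integration.py actually does.

import os
import pinecone

# Read Pinecone credentials from the environment (e.g. Space secrets) instead of
# hard-coding them in source, as the removed code's own comments already suggest.
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
PINECONE_ENV = os.environ.get('PINECONE_ENV', 'us-west4-gcp')

pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)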