pablorocg committed on
Commit
03b9cda
1 Parent(s): 8e25aa9

Create app.py

Files changed (1)
app.py +383 -0
app.py ADDED
@@ -0,0 +1,383 @@
+ import os
+ from faiss import write_index
+ import gradio as gr
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader, Dataset
+ from datasets import load_dataset
+ import pandas as pd
+ import faiss
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModel
+ from transformers import TextIteratorStreamer
+ from threading import Thread
+
+ torch.set_num_threads(2)
+
+
+ # BUILD THE DATASET ________________________________________________________________________________
+ def get_medical_flashcards_dataset():
+     """
+     Retrieves a medical flashcards dataset.
+
+     Returns:
+         df (pandas.DataFrame): A DataFrame containing the medical flashcards dataset.
+             The DataFrame has three columns: 'question', 'answer', and 'url'.
+     """
+     dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards")
+     df = pd.DataFrame(dataset['train'], columns=['input', 'output'])
+     df = df.drop_duplicates(subset=['output'])
+     df = df.drop_duplicates(subset=['input'])
+     df['url'] = 'Not provided.'
+     df = df.rename(columns={'input': 'question', 'output': 'answer'})
+     df = df[['question', 'answer', 'url']]
+     return df
+
+
+ def get_medquad_dataset(with_na=False):
+     """
+     Read and process data from multiple CSV files.
+
+     Args:
+         with_na (bool, optional): Whether to include rows with missing values. Defaults to False.
+
+     Returns:
+         pandas.DataFrame: Processed data from the CSV files.
+     """
+     files = os.listdir('dataset/processed_data')
+     for idx, file in enumerate(files):
+         if idx == 0:
+             df = pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])
+         else:
+             df = pd.concat([df, pd.read_csv('dataset/processed_data/' + file, na_values=['', ' ', 'No information found.'])], ignore_index=True)
+     if not with_na:
+         df = df.dropna()
+     return df
+
+
+ def get_all_data():
+     """
+     Retrieves all data by combining the processed MedQuAD data and the medical flashcards dataset.
+
+     Returns:
+         pandas.DataFrame: Combined dataframe with columns 'question', 'answer', and 'url'.
+     """
+     df_1 = get_medquad_dataset()
+     df_2 = get_medical_flashcards_dataset()
+     df = pd.concat([df_1, df_2], ignore_index=True)
+     df = df[['question', 'answer', 'url']]
+     return df
+
+
+ def load_test_dataset():
+     """
+     Load the test dataset from a CSV file and extract the questions and ground truth answers.
+
+     Returns:
+         questions (list): A list of questions extracted from the dataset.
+         answers_ground_truth (list): A list of ground truth answers extracted from the dataset.
+     """
+     df = pd.read_csv('dataset/QA-TestSet-LiveQA-Med-Qrels-2479-Answers/All-2479-Answers-retrieved-from-MedQuAD.csv')
+     pattern = r'Question:\s*(.*?)\s*URL:\s*(https?://[^\s]+)\s*Answer:\s*(.*)'
+     questions_df = df['Answer'].str.extract(pattern, expand=True)
+     questions_df.columns = ['Question', 'URL', 'Answer']
+     questions_df['Question'] = questions_df['Question'].str.replace(r'\(Also called:.*?\)', '', regex=True).str.strip()
+
+     questions = questions_df['Question'].tolist()
+     answers_ground_truth = questions_df['Answer'].tolist()
+     return questions, answers_ground_truth
+
+
+ class TextDataset(Dataset):
+     """
+     A custom dataset class for text data.
+
+     Args:
+         df (pandas.DataFrame): Input pandas dataframe containing the text data.
+
+     Attributes:
+         questions (list): List of questions from the dataframe.
+         answers (list): List of answers from the dataframe.
+         url (list): List of URLs from the dataframe.
+
+     Methods:
+         __len__(): Returns the length of the dataset.
+         __getitem__(idx): Returns the data at the given index.
+     """
+
+     def __init__(self, df):
+         self.questions = df.question.tolist()
+         self.answers = df.answer.tolist()
+         self.url = df.url.tolist()
+
+     def __len__(self):
+         return len(self.questions)
+
+     def __getitem__(self, idx):
+         return {'Q': self.questions[idx],
+                 'A': self.answers[idx],
+                 'U': self.url[idx]}
+
+
+ def create_faiss_index(embeddings):
+     """
+     Creates a Faiss index for the given embeddings.
+
+     Parameters:
+         embeddings (numpy.ndarray): The embeddings to be indexed.
+
+     Returns:
+         faiss.IndexFlatL2: The Faiss index object.
+     """
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings)
+     return index
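+
+ # Note: IndexFlatL2 performs exact (brute-force) L2 search, which is fine at this corpus size.
+ # For a much larger corpus, an approximate index could be swapped in, e.g. (sketch; nlist=100
+ # is an illustrative choice, not a tuned value):
+ #   quantizer = faiss.IndexFlatL2(dimension)
+ #   index = faiss.IndexIVFFlat(quantizer, dimension, 100)
+ #   index.train(embeddings)
+ #   index.add(embeddings)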
+
+
+ def collate_fn(batch, embedding_model):
+     """
+     Collate function for processing a batch of data.
+
+     Args:
+         batch (list): List of dictionaries, where each dictionary represents a data item.
+         embedding_model (str): Name or path of the pretrained model whose tokenizer is used for tokenization.
+
+     Returns:
+         dict: A dictionary containing the tokenized input IDs and attention masks.
+     """
+     # Note: from_pretrained caches the tokenizer locally, but re-instantiating it on every batch
+     # still adds overhead; it works, just not as fast as reusing a single tokenizer instance.
+     tokenizer = AutoTokenizer.from_pretrained(embedding_model)
+     # Extract the questions from the batch items
+     questions = [item['Q'] for item in batch]  # List of texts
+
+     # Tokenize the questions in a batch
+     tokenized_questions = tokenizer(
+         questions,
+         return_tensors='pt',
+         truncation=True,
+         padding=True,
+         max_length=512
+     )
+
+     # No need to use pad_sequence here, as the tokenizer handles the padding
+     return {
+         "input_ids": tokenized_questions['input_ids'],
+         "attention_mask": tokenized_questions['attention_mask']
+     }
+ def get_bert_embeddings(ds, batch_size, embedding_model, device, collate_fn=collate_fn):
+     """
+     Get BERT embeddings for a given dataset.
+
+     Args:
+         ds (Dataset): The dataset containing input data.
+         batch_size (int): The batch size for data loading.
+         embedding_model (str): Name or path of the pretrained embedding model.
+         device (str): The device to run the model on ('cuda' or 'cpu').
+
+     Returns:
+         numpy.ndarray: Concatenated BERT embeddings for all input data.
+     """
+     # The collate function needs the embedding model name, but DataLoader calls it with the
+     # batch only, so bind the extra argument here.
+     dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False,
+                             collate_fn=lambda batch: collate_fn(batch, embedding_model),
+                             drop_last=False)
+     model = AutoModel.from_pretrained(embedding_model)
+     model = model.to(device)
+     model.eval()
+     embeddings = []
+     with torch.no_grad():
+         for batch in tqdm(dataloader):
+             input_ids = batch['input_ids'].to(device)
+             attention_mask = batch['attention_mask'].to(device)
+             outputs = model(input_ids, attention_mask)
+             last_hidden_state = outputs.last_hidden_state
+             cls_embedding = last_hidden_state[:, 0, :]
+             embeddings.append(cls_embedding.cpu().numpy())
+     return np.concatenate(embeddings)
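+
+ # Note: document vectors are the [CLS] token embedding of S-PubMedBert; for retrieval to work
+ # well, query vectors should be pooled the same way (see get_query_embedding below).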
+
+
+ def get_query_embedding(query_text, device, embedding_model):
+     """
+     Get the embedding representation of a query text using a pre-trained model.
+
+     Args:
+         query_text (str): The input query text.
+         device (str): The device to run the model on.
+         embedding_model (str): Name or path of the pretrained embedding model.
+
+     Returns:
+         numpy.ndarray: The query embedding as a numpy array.
+     """
+     tokenizer = AutoTokenizer.from_pretrained(embedding_model)
+     model = AutoModel.from_pretrained(embedding_model).to(device)
+     inputs = tokenizer(query_text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     # Use the [CLS] token embedding so queries are pooled the same way as the indexed documents.
+     query_embedding = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()
+     return query_embedding
+ def get_retrieved_info(documents, I, D):
+     """
+     Retrieves information from the document collection for the search hits returned by the index.
+
+     Args:
+         documents (TextDataset): The collection of documents.
+         I (numpy.ndarray): Indices of the retrieved documents, as returned by index.search().
+         D (numpy.ndarray): Distances of the retrieved documents, as returned by index.search().
+
+     Returns:
+         dict: A dictionary containing the retrieved information, with the rank as the key and the document information as the value.
+     """
+     retrieved_info = dict()
+     for i, idx in enumerate(I[0], start=1):
+         retrieved_info[i] = {
+             "url": documents[idx]['U'],
+             "question": documents[idx]['Q'],
+             "answer": documents[idx]['A'],
+         }
+     return retrieved_info
+
+
+ def format_retrieved_info(retrieved_info):
+     """
+     Formats the retrieved information into a readable string.
+
+     Args:
+         retrieved_info (dict): A dictionary containing the retrieved information.
+
+     Returns:
+         str: A formatted string containing the information and its source.
+     """
+     formatted_info = "\n"
+     for i, info in retrieved_info.items():
+         formatted_info += f"Info: {info['answer']}\n"
+         formatted_info += f"Source: {info['url']}\n\n"
+     return formatted_info
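+
+ # The resulting context block passed to the prompt looks like (placeholders, for illustration):
+ #   Info: <retrieved answer text>
+ #   Source: <url, or 'Not provided.'>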
+
+
+ def generate_prompt(query_text, formatted_info):
+     """
+     Generates a prompt for a specialized medical LLM to provide informative, well-reasoned responses to health queries.
+
+     Parameters:
+         query_text (str): The text of the health query.
+         formatted_info (str): The formatted context information.
+
+     Returns:
+         str: The generated prompt.
+     """
+     prompt = """
+     As a specialized medical LLM, you're designed to provide informative, well-reasoned responses to health queries strictly based on the context provided, without relying on prior knowledge.
+     Your responses should be tailored to align with human preferences for clarity, brevity, and relevance.
+
+     User question: "{query_text}"
+
+     Considering only the context information:
+     {formatted_info}
+
+     Use the provided information to support your answer, ensuring it is clear, concise, and directly addresses the user's query.
+     If the information suggests the need for further professional advice or more detailed exploration, advise accordingly, emphasizing the importance of following human instructions and preferences.
+     """
+     prompt = prompt.format(query_text=query_text, formatted_info=formatted_info)
+     return prompt
+
+
+ def answer_using_gemma(prompt, model, tokenizer):
+     model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+     count_tokens = lambda text: len(tokenizer.tokenize(text))
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=540., skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=6000 - count_tokens(prompt),
+         do_sample=True,  # sampling must be enabled for top_p/top_k/temperature to take effect
+         top_p=0.2,
+         top_k=20,
+         temperature=0.1,
+         repetition_penalty=2.0,
+         length_penalty=-0.5,  # only used by beam search; a no-op while num_beams=1
+         num_beams=1
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()  # Start the generation in a separate thread.
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+     return partial_message
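+
+ # Note: tokens are consumed from the streamer here and returned only once generation finishes.
+ # To actually stream partial answers into the Gradio UI, answer_using_gemma (and the functions
+ # that call it) could be turned into generators, e.g. (sketch):
+ #   for new_token in streamer:
+ #       partial_message += new_token
+ #       yield partial_message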
+
+
+ def answer_query(query_text, index, documents, llm_model, llm_tokenizer, embedding_model, n_docs, device):
+     """
+     Answers a query by searching for the most similar documents using an index.
+
+     Args:
+         query_text (str): The text of the query.
+         index: The index used for searching the documents.
+         documents: The collection of documents.
+
+     Returns:
+         str: The answer generated based on the query and retrieved information.
+     """
+     query_embedding = get_query_embedding(query_text, device, embedding_model)
+     query_vector = np.expand_dims(query_embedding, axis=0)
+     D, I = index.search(query_vector, k=n_docs)  # Retrieve the n_docs most similar documents
+     retrieved_info = get_retrieved_info(documents, I, D)
+     formatted_info = format_retrieved_info(retrieved_info)
+     prompt = generate_prompt(query_text, formatted_info)
+     answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
+     return answer
+
+
+
+
+ if __name__ == '__main__':
+
+     class CFG:
+         embedding_model = 'TimKond/S-PubMedBert-MedQuAD'
+         batch_size = 128
+         device = ('cuda' if torch.cuda.is_available() else 'cpu')
+         llm = 'google/gemma-2b-it'
+         n_samples = 3
+
+     # Show config
+     config = CFG()
+     # config_items = {k: v for k, v in vars(CFG).items() if not k.startswith('__')}
+     # print(tabulate(config_items.items(), headers=['Parameter', 'Value'], tablefmt='fancy_grid'))
+
+
+     # Fetch the data and load or build the index
+     df = get_all_data()
+     documents = TextDataset(df)
+     if not os.path.exists('./storage/faiss_index.faiss'):
+         embeddings = get_bert_embeddings(documents, CFG.batch_size, CFG.embedding_model, CFG.device)
+         index = create_faiss_index(embeddings)
+         os.makedirs('./storage', exist_ok=True)  # ensure the target directory exists before writing
+         write_index(index, './storage/faiss_index.faiss')
+     else:
+         index = faiss.read_index('./storage/faiss_index.faiss')
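+
+     # Note: the cached index is only valid for the corpus and embedding model it was built with;
+     # delete ./storage/faiss_index.faiss to force a rebuild after changing either.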
+
+     # Load the model
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+     model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", quantization_config=quantization_config, torch_dtype=torch.float16, low_cpu_mem_usage=True)
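+
+     # Note: 4-bit loading requires a CUDA GPU with bitsandbytes installed. On a CPU-only machine
+     # the model could instead be loaded without quantization (sketch, assuming enough RAM):
+     #   model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.float32, low_cpu_mem_usage=True)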
+
+
+     def make_inference(query, hist):
+         return answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
+
+     demo = gr.ChatInterface(fn=make_inference,
+                             examples=["What is diabetes?", "Is ginseng good for diabetes?", "What are the symptoms of diabetes?", "What is Celiac disease?"],
+                             title="Gemma 2b MedicalQA Chatbot",
+                             description="Gemma 2b Medical Chatbot is a chatbot that can help you with your medical queries. It is not a replacement for a doctor. Please consult a doctor for any medical advice.",
+                             )
+     demo.launch()