nttwt1597 committed on
Commit a8be9e1 · verified · 1 Parent(s): bd5d11b

RAG + feedback update

Files changed (1)
  1. app.py +161 -124
app.py CHANGED
@@ -1,86 +1,182 @@
  import os
- token=os.environ['token']
- # token_r=os.environ['token_r']
- # token_w=os.environ['token_w']

  import torch
  import gradio as gr
- from unsloth import FastLanguageModel
- from peft import PeftConfig, PeftModel, get_peft_model
- from transformers import pipeline, TextIteratorStreamer
- from threading import Thread
-
- # For getting tokenizer()
- model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
- peft_model_adapter_id = "nttwt1597/test_v2_cancer_v4_checkpoint2900"
-
- model, tokenizer = FastLanguageModel.from_pretrained(
-     model_name = model_id,
-     max_seq_length = 4096,
-     dtype = None,
-     load_in_4bit = True,
  )
- model.load_adapter(peft_model_adapter_id, token=token)

- terminators = [
      tokenizer.eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
  ]

- FastLanguageModel.for_inference(model)

- criteria_prompt = """Based on the provided instructions and clinical trial information, generate the eligibility criteria for the study.
-
- ### Instruction:
- As a clinical researcher, generate comprehensive eligibility criteria to be used in clinical research based on the given clinical trial information. Ensure the criteria are clear, specific, and suitable for a clinical research setting.
-
- ### Clinical trial information:
- {}
-
- ### Eligibility criteria:
- {}"""

- def format_prompt(text):
-     return criteria_prompt.format(text, "")

- def run_model_on_text(text):
-     prompt = format_prompt(text)
-     inputs = tokenizer(prompt, return_tensors='pt')
-
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     generation_kwargs = dict(inputs, streamer=streamer, eos_token_id=terminators, max_new_tokens=1024, repetition_penalty=1.175)
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     generated_text = ""
-     for new_text in streamer:
-         generated_text += new_text
-         yield generated_text

  place_holder = f"""Study Objectives
- The purpose of this study is to evaluate the safety, tolerance and efficacy of Liposomal Paclitaxel With Nedaplatin as First-line in patients with Advanced or Recurrent Esophageal Carcinoma

- Conditions: Esophageal Carcinoma

- Intervention / Treatment:
- DRUG: Liposomal Paclitaxel,
- DRUG: Nedaplatin

- Location: China

- Study Design and Phases
- Study Type: INTERVENTIONAL
- Phase: PHASE2 Primary Purpose:
- TREATMENT Allocation: NA
  Interventional Model: SINGLE_GROUP Masking: NONE
  """

  prefilled_value = """Study Objectives
- [Brief Summary] and/or [Detailed Description]

  Conditions: [Disease]

- Intervention / Treatment
  [DRUGs]

  Location
@@ -90,92 +186,33 @@ Study Design and Phases
  Study Type:
  Phase:
  Primary Purpose:
- Allocation:
  Interventional Model:
  Masking:"""

-
- # hf_writer = gr.HuggingFaceDatasetSaver("ravistech/criteria-feedback-demo", token, private=True)
- # with gr.Blocks() as demo:
- #     with gr.Row():
- #         with gr.Column():
- #             prompt_box = gr.Textbox(
- #                 label="Research Information",
- #                 placeholder=place_holder,
- #                 value=prefilled_value,
- #                 lines=10)
- #             submit_button = gr.Button("Generate")
- #         with gr.Column():
- #             output_box = gr.Textbox(
- #                 label="Eligiblecriteria Criteria",
- #                 lines=21,
- #                 interactive=False)
- #     with gr.Row():
- #         with gr.Column():
- #             feedback_box = gr.Textbox(label="Enter your feedback here...", lines=3, interactive=True)
- #             feedback_button = gr.Button("Submit Feedback")
- #             status_text = gr.Textbox(label="Status", lines=1, interactive=False)
-
- #     submit_button.click(
- #         run_model_on_text,
- #         inputs=prompt_box,
- #         outputs=output_box
- #     )
-
- #     def submit_feedback(prompt, generated_text, feedback):
- #         data = {
- #             "prompt": prompt,
- #             "generated_text": generated_text,
- #             "feedback": feedback
- #         }
- #         hf_writer.flag(data)
- #         return "Feedback submitted."
-
- #     feedback_button.click(
- #         submit_feedback,
- #         inputs=[prompt_box, output_box, feedback_box],
- #         outputs=status_text
- #     )
-
- #     feedback_button.click(
- #         hf_writer.flag([prompt_box, output_box, feedback_box]),
- #         # lambda *args: hf_writer.flag(args),
- #         inputs=[prompt_box, output_box, feedback_box],
- #         outputs=status_text,
- #     )
-
- #     gr.Interface(lambda x: x, "text", "text", allow_flagging="manual", flagging_callback=hf_writer)
-
- #     feedback_button.click(
- #         save_feedback,
- #         inputs=[prompt_box, output_box, feedback_box],
- #         outputs=status_text
- #     )
-
- # demo.launch()
-
- #----------------------------------
  prompt_box = gr.Textbox(
-     lines=25,
      label="Research Information",
-     placeholder=place_holder,
      value=prefilled_value,
  )

  output_box = gr.Textbox(
-     lines=25,
      label="Eligibility Criteria",
  )

  demo = gr.Interface(
-     fn=run_model_on_text,
      inputs=prompt_box,
      outputs=output_box,
-     # allow_flagging="manual",
-     # flagging_options=["incorrect", "inappropriate", "appropriate"],
-     # flagging_callback=hf_writer
  )

  demo.queue(max_size=20).launch(debug=True, share=True)
 
  import os
+ token_r = os.environ['token_r']
+ token_w = os.environ['token_w']
+ token_w_feedback = os.environ['token_w_feedback']
+
+ from llama_index.core import Settings
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+ # Dependencies must be installed ahead of time (e.g. via requirements.txt):
+ #   pip install -qU llama-index-vector-stores-elasticsearch llama-index-embeddings-huggingface llama-index
+ # (a bare `!pip install` line is notebook syntax and is a SyntaxError in a plain .py file)
+ from llama_index.vector_stores.elasticsearch import ElasticsearchStore
+
+ from llama_index.core.query_engine import CitationQueryEngine
+ from llama_index.core import VectorStoreIndex
+
  import torch
+ from transformers import AutoTokenizer
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from transformers import BitsAndBytesConfig
+
  import gradio as gr
+
+ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name,
+     token=token_r,
  )

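+ # Llama-3-Instruct ends an assistant turn with <|eot_id|> rather than only the eos
+ # token, so both ids are treated as stop tokens; otherwise generation can run past the answer.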
+ stopping_ids = [
      tokenizer.eos_token_id,
+     tokenizer.convert_tokens_to_ids("<|eot_id|>"),
  ]

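+ # 4-bit NF4 quantization with double quantization shrinks the 8B model's weights to a
+ # fraction of their float16 size so it fits on a single GPU; compute still runs in float16.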
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+ )

+ # Get the model
+ llm = HuggingFaceLLM(
+     model_name="meta-llama/Meta-Llama-3-8B-Instruct",
+     tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
+     model_kwargs={
+         "token": token_r,
+         "quantization_config": quantization_config,
+     },
+     context_window=8191,
+     max_new_tokens=2048,
+     generate_kwargs={
+         # "do_sample": True,
+         # "temperature": 0.1,
+         # "top_p": 0.9,
+         "repetition_penalty": 1,
+     },
+     stopping_ids=stopping_ids,
+ )

+ # bge embedding model
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
+ Settings.embed_model = embed_model

+ # Llama-3-8B-Instruct model
+ Settings.llm = llm

+ # Get data from Elasticsearch
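+ # NOTE: es_cloud_id and es_password are not defined anywhere in this diff; they are
+ # presumably read from Space secrets via os.environ, like the Hugging Face tokens above.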
+ es_vector_store = ElasticsearchStore(
+     index_name="train_criteria_index",
+     es_cloud_id=es_cloud_id,
+     es_user="elastic",
+     es_password=es_password,
+ )

+ index_es = VectorStoreIndex.from_vector_store(es_vector_store)

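+ # CitationQueryEngine retrieves the most similar indexed trials (top 10) and keeps track
+ # of the source chunks behind its answer, so retrieved studies can later be cited by NCT ID.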
+ query_engine_get_study = CitationQueryEngine.from_args(
+     index_es,
+     similarity_top_k=10,
+     citation_chunk_size=2048,
+     verbose=True,
+ )
+
+ def get_prompt(text):
+     studies_response = query_engine_get_study.query(f"""
+     Based on the provided instructions and clinical trial information, what are the eligibility criteria for the given clinical trial?
+     Ensure the studies are relevant and have similar study information. Prioritize the following topics when finding related studies:
+     1. Conditions
+     2. Intervention/Treatment
+     3. Study Objectives
+     4. Study Design and Phases
+
+     ### Clinical Trial Information:
+     {text}
+     """)
+
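+     # Collect the retrieved chunk texts and their metadata (study identifiers) so the
+     # generation prompt below can ground the new criteria in similar trials.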
+     study_ref = []
+     metadata_list = []
+     for source in studies_response.source_nodes:
+         ref = source.node.get_text()
+         study_ref.append(ref)
+         meta_data = source.node.get_metadata_str()
+         metadata_list.append(meta_data)
+
+     criteria_response = llm.stream_complete(f"""
+     Based on the provided instructions and clinical trial information, generate the eligibility criteria for the study.
+
+     ## Instruction:
+     You are a clinical researcher able to generate new comprehensive eligibility criteria for clinical research based on the given clinical trial information.
+     Analyze the clinical trial information, delimited by ### Clinical Trial Information, together with the information from the retrieved papers, delimited by ### Related data. Choose the suitable criteria and adapt them to the given clinical trial information so that the new eligibility criteria are more precise.
+     Also report the NCT IDs and study names of the papers used, delimited by ### Reference Papers.
+     The pattern of the output is delimited by ### Pattern of the output.
+     Ensure the criteria are clear, specific, and suitable for a clinical research setting.
+
+     Prioritize the following topics from the clinical trial information:
+     1. Conditions
+     2. Intervention/Treatment
+     3. Study Objectives
+     4. Study Design and Phases
+
+     ### Clinical Trial Information
+     {text}
+
+     ### Related data
+     {study_ref}
+
+     ### Reference Papers
+     {metadata_list}
+
+     ### Pattern of the output
+     Inclusion Criteria
+     1.
+     2.
+
+     Exclusion Criteria
+     1.
+     2.
+
+     Reference Papers
+     1. NCT ID:
+        Study Name:
+     2. NCT ID:
+        Study Name:
+     3. NCT ID:
+        Study Name:
+     """)
+
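+     # stream_complete returns an iterator of partial completions; yielding each chunk
+     # lets Gradio render the generated criteria incrementally in the output box.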
+     for chunk in criteria_response:
+         yield chunk

  place_holder = f"""Study Objectives
+ The purpose of this study is to evaluate the safety, tolerance and efficacy of Liposomal Paclitaxel With Nedaplatin as First-line in patients with Advanced or Recurrent Esophageal Carcinoma

+ Conditions: Esophageal Carcinoma

+ Intervention / Treatment:
+ DRUG: Liposomal Paclitaxel,
+ DRUG: Nedaplatin

+ Location: China

+ Study Design and Phases
+ Study Type: INTERVENTIONAL
+ Phase: PHASE2 Primary Purpose:
+ TREATMENT Allocation: NA
  Interventional Model: SINGLE_GROUP Masking: NONE
  """

  prefilled_value = """Study Objectives
+ [Brief Summary and/or Detailed Description]

  Conditions: [Disease]

+ Intervention / Treatment
  [DRUGs]

  Location

  Study Type:
  Phase:
  Primary Purpose:
+ Allocation:
  Interventional Model:
  Masking:"""
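+ # HuggingFaceDatasetSaver appends flagged (input, output, flag) rows to a private
+ # dataset repo on the Hub, authenticated with the write token above.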
+ hf_writer = gr.HuggingFaceDatasetSaver(hf_token=token_w_feedback, dataset_name="nttwt1597/criteria-feedback-demo-1", private=True)
+

  prompt_box = gr.Textbox(
+     lines=10,
      label="Research Information",
+     # placeholder=place_holder,
      value=prefilled_value,
  )

  output_box = gr.Textbox(
+     lines=10,
      label="Eligibility Criteria",
  )

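+ # With allow_flagging="manual", each click on a flag option sends the current
+ # prompt, the generated criteria, and the chosen label to hf_writer.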
  demo = gr.Interface(
+     fn=get_prompt,
      inputs=prompt_box,
      outputs=output_box,
+     # allow_flagging='auto',
+     allow_flagging="manual",
+     flagging_options=["appropriate", "inappropriate", "incorrect"],
+     flagging_callback=hf_writer,
+     # live=True
  )

  demo.queue(max_size=20).launch(debug=True, share=True)