1-ARIjitS commited on
Commit
86f6253
1 Parent(s): 52ee7a9

tagging included

Browse files
Files changed (1) hide show
  1. llm_res.py +149 -54
llm_res.py CHANGED
@@ -44,23 +44,106 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
44
  return clinical_records
45
 
46
 
47
- def process_json_data_for_llm(data):
48
-
49
- # Define the fields you want to keep
50
- fields_to_keep = [
51
- "class_of_organization",
52
- "title",
53
- "overallStatus",
54
- "descriptionModule",
55
- "conditions",
56
- "interventions",
57
- "outcomesModule",
58
- "eligibilityModule",
59
- ]
60
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Iterate through the dictionary and keep only the desired fields
62
  filtered_data = []
63
- for item in data:
64
  try:
65
  organization_name = item["protocolSection"]["identificationModule"][
66
  "organization"
@@ -132,22 +215,24 @@ def process_json_data_for_llm(data):
132
  "eligibility": eligibility,
133
  }
134
  filtered_data.append(filtered_item)
135
-
136
- # for ele in filtered_data:
137
- # print(ele)
138
-
139
 
140
  def get_short_summary_out_of_json_files(data_json):
141
- prompt_template = """ You are an expert clinician working on the analysis of reports of clinical trials.
 
 
 
142
 
143
- # Task
144
- You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
145
 
146
- To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
147
 
148
- {text}
149
 
150
- General summary:"""
 
 
151
 
152
  prompt = PromptTemplate.from_template(prompt_template)
153
 
@@ -178,18 +263,31 @@ General summary:"""
178
  print(f"Combined descriptions: {combined_descriptions}")
179
 
180
  result = stuff_chain.run(combined_descriptions)
181
- print(f"Result: {result}")
182
 
183
  return result
184
 
 
 
 
 
 
 
 
 
185
 
186
- def taggingTemplate():
 
 
 
 
 
187
  class Classification(BaseModel):
188
- description: str = Field(
189
- description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
190
- )
191
  project_title: list = Field(
192
- description="Extract the project title of all the clinical trials"
193
  )
194
  status: list = Field(
195
  description="Extract the status of all the clinical trials"
@@ -207,43 +305,45 @@ def taggingTemplate():
207
  # eligibility: list = Field(
208
  # description="get the eligibilityCriteria grouping all the clinical trials"
209
  # )
210
- # healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
211
- # minimum_age: list = Field(
212
- # description="get the minimum age from each experiment"
213
- # )
214
- # maximum_age: list = Field(
215
- # description="get the maximum age from each experiment"
216
- # )
217
- # gender: list = Field(description="get the gender from each experiment")
218
 
219
  def get_dict(self):
220
  return {
221
- "summary": self.description,
222
  "project_title": self.project_title,
223
  "status": self.status,
224
- "keywords": self.keywords,
225
  "interventions": self.interventions,
226
  "primary_outcomes": self.primary_outcomes,
227
  # "secondary_outcomes": self.secondary_outcomes,
228
- "eligibility": self.eligibility,
229
- # "healthy_volunteers": self.healthy_volunteers,
230
  "minimum_age": self.minimum_age,
231
  "maximum_age": self.maximum_age,
232
- "gender": self.gender,
233
  }
234
 
235
  # LLM
236
  llm = ChatOpenAI(
237
  temperature=0.6,
238
- model="gpt-4",
239
  openai_api_key=os.environ["OPENAI_API_KEY"],
240
  ).with_structured_output(Classification)
241
 
242
- stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
243
 
244
- # tagging_chain = prompt_template | llm
245
 
246
- # return tagging_chain
 
 
 
247
 
248
 
249
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
@@ -252,10 +352,5 @@ def taggingTemplate():
252
  # with open('data.json', 'w') as f:
253
  # json.dump(clinical_record_info, f, indent=4)
254
 
255
- # tagging_chain = llm_config()
256
-
257
 
258
- def process_dictionaty_with_llm_to_generate_response(json_contents):
259
- processed_data = process_json_data_for_llm(json_contents)
260
- # res = tagging_chain.invoke({"input": processed_data})
261
- # return res
 
44
  return clinical_records
45
 
46
 
47
+ # # def process_json_data_for_llm(data):
48
+
49
+ # # Define the fields you want to keep
50
+ # fields_to_keep = [
51
+ # "class_of_organization",
52
+ # "title",
53
+ # "overallStatus",
54
+ # "descriptionModule",
55
+ # "conditions",
56
+ # "interventions",
57
+ # "outcomesModule",
58
+ # "eligibilityModule",
59
+ # ]
60
+
61
+ # # Iterate through the dictionary and keep only the desired fields
62
+ # filtered_data = []
63
+ # for item in data:
64
+ # try:
65
+ # organization_name = item["protocolSection"]["identificationModule"][
66
+ # "organization"
67
+ # ]["fullName"]
68
+ # except:
69
+ # organization_name = ""
70
+ # try:
71
+ # project_title = item["protocolSection"]["identificationModule"][
72
+ # "officialTitle"
73
+ # ]
74
+ # except:
75
+ # project_title = ""
76
+ # try:
77
+ # status = item["protocolSection"]["statusModule"]["overallStatus"]
78
+ # except:
79
+ # status = ""
80
+ # try:
81
+ # briefDescription = item["protocolSection"]["descriptionModule"][
82
+ # "briefSummary"
83
+ # ]
84
+ # except:
85
+ # briefDescription = ""
86
+ # try:
87
+ # detailedDescription = item["protocolSection"]["descriptionModule"][
88
+ # "detailedDescription"
89
+ # ]
90
+ # except:
91
+ # detailedDescription = ""
92
+ # try:
93
+ # conditions = item["protocolSection"]["conditionsModule"]["conditions"]
94
+ # except:
95
+ # conditions = []
96
+ # try:
97
+ # keywords = item["protocolSection"]["conditionsModule"]["keywords"]
98
+ # except:
99
+ # keywords = []
100
+ # try:
101
+ # interventions = item["protocolSection"]["armsInterventionsModule"][
102
+ # "interventions"
103
+ # ]
104
+ # except:
105
+ # interventions = []
106
+ # try:
107
+ # primary_outcomes = item["protocolSection"]["outcomesModule"][
108
+ # "primaryOutcomes"
109
+ # ]
110
+ # except:
111
+ # primary_outcomes = []
112
+ # try:
113
+ # secondary_outcomes = item["protocolSection"]["outcomesModule"][
114
+ # "secondaryOutcomes"
115
+ # ]
116
+ # except:
117
+ # secondary_outcomes = []
118
+ # try:
119
+ # eligibility = item["protocolSection"]["eligibilityModule"]
120
+ # except:
121
+ # eligibility = {}
122
+ # filtered_item = {
123
+ # "organization_name": organization_name,
124
+ # "project_title": project_title,
125
+ # "status": status,
126
+ # "briefDescription": briefDescription,
127
+ # "detailedDescription": detailedDescription,
128
+ # "keywords": keywords,
129
+ # "interventions": interventions,
130
+ # "primary_outcomes": primary_outcomes,
131
+ # "secondary_outcomes": secondary_outcomes,
132
+ # "eligibility": eligibility,
133
+ # }
134
+ # filtered_data.append(filtered_item)
135
+
136
+ # return filtered_data
137
+ # # for ele in filtered_data:
138
+ # # print(ele)
139
+
140
+ def process_dictionaty_with_llm_to_generate_response(json_data):
141
+ # processed_data = process_json_data_for_llm(json_data)
142
+ # res = tagging_chain.invoke({"input": processed_data})
143
+ # return res
144
  # Iterate through the dictionary and keep only the desired fields
145
  filtered_data = []
146
+ for item in json_data:
147
  try:
148
  organization_name = item["protocolSection"]["identificationModule"][
149
  "organization"
 
215
  "eligibility": eligibility,
216
  }
217
  filtered_data.append(filtered_item)
218
+
219
+ return filtered_data
 
 
220
 
221
  def get_short_summary_out_of_json_files(data_json):
222
+ # prompt_template = """ You are an expert clinician working on the analysis of reports of clinical trials.
223
+
224
+ # # Task
225
+ # You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
226
 
227
+ # To write your summary, you will need to read the following examples, labeled as "Report 1", "Report 2", and so on. Your answer should be a single paragraph (100-200 words) that summarizes the general content of all the reports.
 
228
 
229
+ # {text}
230
 
231
+ # General summary:"""
232
 
233
+ prompt_template = """ You are an expert on clinicial trials and their analysis of their reports.
234
+ # Task
235
+ You will be given a text of descriptions of multiple clinical trials realed to similar diseases. Your job is to come up with a short and detailed summary of the descriptions of the clinical trials. Your users are clinical researchers, so you should be technical and specific, including scientific terms in the summary."""
236
 
237
  prompt = PromptTemplate.from_template(prompt_template)
238
 
 
263
  print(f"Combined descriptions: {combined_descriptions}")
264
 
265
  result = stuff_chain.run(combined_descriptions)
266
+ print(f"Result_summarization: {result}")
267
 
268
  return result
269
 
270
+ def tagging_insights_from_json(data_json):
271
+ processed_json= process_dictionaty_with_llm_to_generate_response(data_json)
272
+
273
+ tagging_prompt = ChatPromptTemplate.from_template(
274
+ """
275
+ You are an expert on clinicial trials and analysis of their reports.
276
+
277
+ Extract the desired information from the following JSON data.
278
 
279
+ Only extract the properties mentioned in the 'Classification' function.
280
+
281
+ JSON data:
282
+ {input}
283
+ """
284
+ )
285
  class Classification(BaseModel):
286
+ # description: str = Field(
287
+ # description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
288
+ # )
289
  project_title: list = Field(
290
+ description="Extract the project titles of all the clinical trials"
291
  )
292
  status: list = Field(
293
  description="Extract the status of all the clinical trials"
 
305
  # eligibility: list = Field(
306
  # description="get the eligibilityCriteria grouping all the clinical trials"
307
  # )
308
+ healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
309
+ minimum_age: list = Field(
310
+ description="get the minimum age from each experiment"
311
+ )
312
+ maximum_age: list = Field(
313
+ description="get the maximum age from each experiment"
314
+ )
315
+ gender: list = Field(description="get the gender from each experiment")
316
 
317
  def get_dict(self):
318
  return {
 
319
  "project_title": self.project_title,
320
  "status": self.status,
321
+ # "keywords": self.keywords,
322
  "interventions": self.interventions,
323
  "primary_outcomes": self.primary_outcomes,
324
  # "secondary_outcomes": self.secondary_outcomes,
325
+ # "eligibility": self.eligibility,
326
+ "healthy_volunteers": self.healthy_volunteers,
327
  "minimum_age": self.minimum_age,
328
  "maximum_age": self.maximum_age,
329
+ "gender": self.gender
330
  }
331
 
332
  # LLM
333
  llm = ChatOpenAI(
334
  temperature=0.6,
335
+ model="gpt-4-turbo",
336
  openai_api_key=os.environ["OPENAI_API_KEY"],
337
  ).with_structured_output(Classification)
338
 
339
+ # stuff_chain = StuffDocumentsChain(llm_chain=llm, document_variable_name="text")
340
 
341
+ tagging_chain = tagging_prompt | llm
342
 
343
+ res= tagging_chain.invoke({"input": processed_json})
344
+ result_dict= res.get_dict()
345
+ print(f"Result_tagging: {result_dict}")
346
+ return result_dict
347
 
348
 
349
  # clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
 
352
  # with open('data.json', 'w') as f:
353
  # json.dump(clinical_record_info, f, indent=4)
354
 
 
 
355
 
356
+ # tagging_chain = tagging_insights_from_json(json_data)