Kevin Wu commited on
Commit
95174f7
·
1 Parent(s): 854997c
Files changed (2) hide show
  1. app.py +179 -156
  2. requirements.txt +1 -1
app.py CHANGED
@@ -4,184 +4,203 @@ import os
4
  import time
5
  import gradio as gr
6
  from openai import OpenAI
7
-
8
  import xml.etree.ElementTree as ET
9
  import re
10
  import pandas as pd
11
-
12
  import prompts
 
13
 
14
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
15
 
16
  model_name = "gpt-4o-2024-08-06"
17
 
18
- demo = client.beta.assistants.create(
19
- name="Information Extractor",
20
- instructions="Extract information from this note.",
21
- model=model_name,
22
- tools=[{"type": "file_search"}],
23
- )
 
 
 
 
24
 
25
  def parse_xml_response(xml_string: str) -> pd.DataFrame:
26
  """
27
  Parse the XML response from the model and extract all fields into a dictionary,
28
  then convert it to a pandas DataFrame with a nested index.
29
  """
30
- # Extract only the XML content between the first and last tags
31
- xml_content = re.search(r'<.*?>.*</.*?>', xml_string, re.DOTALL)
32
- if xml_content:
33
- xml_string = xml_content.group(0)
34
- else:
35
- print("No valid XML content found.")
36
- return pd.DataFrame()
37
-
38
  try:
 
 
 
 
 
 
 
 
39
  root = ET.fromstring(xml_string)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  except ET.ParseError as e:
41
- print(f"Error parsing XML: {e}")
 
 
 
 
 
42
  return pd.DataFrame()
43
-
44
- result = {}
45
-
46
- for element in root:
47
- tag = element.tag
48
- if tag in ['patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death']:
49
- result[tag] = {
50
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
51
- **{child.tag: child.text.strip() if child.text else None
52
- for child in element if child.tag != 'reasoning'}
53
- }
54
- elif tag in ['traditional_chemo', 'other_cancer_treatments', 'other_conmeds']:
55
- if tag not in result:
56
- result[tag] = []
57
- reasoning = element.find('reasoning')
58
- for item in element:
59
- if item.tag in ['drug', 'treatment', 'medication']:
60
- date_element = element.find('date')
61
- result[tag].append({
62
- 'reasoning': reasoning.text.strip() if reasoning is not None else None,
63
- 'name': item.text.strip() if item.text else None,
64
- 'date': date_element.text.strip() if date_element is not None and date_element.text else None
65
- })
66
- elif tag in ['surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis']:
67
- result[tag] = {
68
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
69
- **{child.tag: child.text.strip() if child.text else None
70
- for child in element if child.tag != 'reasoning'}
71
- }
72
- elif tag == 'compounding_pharmacy':
73
- result[tag] = {
74
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
75
- 'pharmacy': element.find('pharmacy').text.strip() if element.find('pharmacy') is not None else None
76
- }
77
- elif tag == 'adverse_effects':
78
- if tag not in result:
79
- result[tag] = []
80
- effect = {
81
- 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None
82
- }
83
- for child in element:
84
- if child.tag != 'reasoning':
85
- effect[child.tag] = child.text.strip() if child.text else None
86
- if effect:
87
- result[tag].append(effect)
88
-
89
- # Convert to nested DataFrame
90
- df_data = {}
91
- for key, value in result.items():
92
- if isinstance(value, dict):
93
- for sub_key, sub_value in value.items():
94
- df_data[(key, '1', sub_key)] = [sub_value]
95
- elif isinstance(value, list):
96
- for i, item in enumerate(value):
97
- for sub_key, sub_value in item.items():
98
- df_data[(key, f"{i+1}", sub_key)] = [sub_value]
99
- else:
100
- df_data[(key, '1', '')] = [value]
101
-
102
- # Create multi-index DataFrame
103
- df = pd.DataFrame(df_data)
104
- df.columns = pd.MultiIndex.from_tuples(df.columns)
105
-
106
- return df
107
 
108
  def get_response(prompt, file_id, assistant_id):
109
- thread = client.beta.threads.create(
110
- messages=[
111
- {
112
- "role": "user",
113
- "content": prompts.info_prompt,
114
- "attachments": [
115
- {"file_id": file_id, "tools": [{"type": "file_search"}]}
116
- ],
117
- }
118
- ]
119
- )
120
- run = client.beta.threads.runs.create_and_poll(
121
- thread_id=thread.id, assistant_id=assistant_id
122
- )
123
- messages = list(
124
- client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
125
- )
126
-
127
- message_content = messages[0].content[0].text
128
- annotations = message_content.annotations
129
- for index, annotation in enumerate(annotations):
130
- message_content.value = message_content.value.replace(annotation.text, f"")
131
- return message_content.value
 
 
 
 
 
132
 
133
  def process(file_content):
134
- if not os.path.exists("cache"):
135
- os.makedirs("cache")
136
- file_name = f"cache/{time.time()}.pdf"
137
- with open(file_name, "wb") as f:
138
- f.write(file_content)
139
-
140
- message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
141
-
142
- response = get_response(prompts.info_prompt, message_file.id, demo.id)
143
- df = parse_xml_response(response)
144
-
145
- if df.empty:
146
- return "<p>No valid information could be extracted from the provided file.</p>"
147
-
148
- # Transpose the DataFrame
149
- df_transposed = df.T.reset_index()
150
- df_transposed.columns = ['Category', 'Index', 'Field', 'Value']
151
- df_transposed = df_transposed.sort_values(['Category', 'Index', 'Field'])
152
-
153
- # Convert to HTML with some basic styling
154
- html = df_transposed.to_html(index=False, classes='table table-striped table-bordered', escape=False)
155
-
156
- # Add some custom CSS for better readability
157
- html = f"""
158
- <style>
159
- .table {{
160
- width: 100%;
161
- max-width: 100%;
162
- margin-bottom: 1rem;
163
- background-color: transparent;
164
- }}
165
- .table td, .table th {{
166
- padding: .75rem;
167
- vertical-align: top;
168
- border-top: 1px solid #dee2e6;
169
- }}
170
- .table thead th {{
171
- vertical-align: bottom;
172
- border-bottom: 2px solid #dee2e6;
173
- }}
174
- .table tbody + tbody {{
175
- border-top: 2px solid #dee2e6;
176
- }}
177
- .table-striped tbody tr:nth-of-type(odd) {{
178
- background-color: rgba(0,0,0,.05);
179
- }}
180
- </style>
181
- {html}
182
- """
183
-
184
- return html
 
 
 
 
 
 
185
 
186
  def gradio_interface():
187
  upload_component = gr.File(label="Upload PDF", type="binary")
@@ -198,4 +217,8 @@ def gradio_interface():
198
  demo.launch()
199
 
200
  if __name__ == "__main__":
201
- gradio_interface()
 
 
 
 
 
4
  import time
5
  import gradio as gr
6
  from openai import OpenAI
 
7
  import xml.etree.ElementTree as ET
8
  import re
9
  import pandas as pd
 
10
  import prompts
11
+ import traceback
12
 
13
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
14
 
15
  model_name = "gpt-4o-2024-08-06"
16
 
17
+ try:
18
+ demo = client.beta.assistants.create(
19
+ name="Information Extractor",
20
+ instructions="Extract information from this note.",
21
+ model=model_name,
22
+ tools=[{"type": "file_search"}],
23
+ )
24
+ except Exception as e:
25
+ print(f"Error creating assistant: {str(e)}")
26
+ raise
27
 
28
  def parse_xml_response(xml_string: str) -> pd.DataFrame:
29
  """
30
  Parse the XML response from the model and extract all fields into a dictionary,
31
  then convert it to a pandas DataFrame with a nested index.
32
  """
 
 
 
 
 
 
 
 
33
  try:
34
+ # Extract only the XML content between the first and last tags
35
+ xml_content = re.search(r'<.*?>.*</.*?>', xml_string, re.DOTALL)
36
+ if xml_content:
37
+ xml_string = xml_content.group(0)
38
+ else:
39
+ print("No valid XML content found.")
40
+ return pd.DataFrame()
41
+
42
  root = ET.fromstring(xml_string)
43
+
44
+ result = {}
45
+
46
+ for element in root:
47
+ tag = element.tag
48
+ if tag in ['patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death']:
49
+ result[tag] = {
50
+ 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
51
+ **{child.tag: child.text.strip() if child.text else None
52
+ for child in element if child.tag != 'reasoning'}
53
+ }
54
+ elif tag in ['traditional_chemo', 'other_cancer_treatments', 'other_conmeds']:
55
+ if tag not in result:
56
+ result[tag] = []
57
+ reasoning = element.find('reasoning')
58
+ for item in element:
59
+ if item.tag in ['drug', 'treatment', 'medication']:
60
+ date_element = element.find('date')
61
+ result[tag].append({
62
+ 'reasoning': reasoning.text.strip() if reasoning is not None else None,
63
+ 'name': item.text.strip() if item.text else None,
64
+ 'date': date_element.text.strip() if date_element is not None and date_element.text else None
65
+ })
66
+ elif tag in ['surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis']:
67
+ result[tag] = {
68
+ 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
69
+ **{child.tag: child.text.strip() if child.text else None
70
+ for child in element if child.tag != 'reasoning'}
71
+ }
72
+ elif tag == 'compounding_pharmacy':
73
+ result[tag] = {
74
+ 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None,
75
+ 'pharmacy': element.find('pharmacy').text.strip() if element.find('pharmacy') is not None else None
76
+ }
77
+ elif tag == 'adverse_effects':
78
+ if tag not in result:
79
+ result[tag] = []
80
+ effect = {
81
+ 'reasoning': element.find('reasoning').text.strip() if element.find('reasoning') is not None else None
82
+ }
83
+ for child in element:
84
+ if child.tag != 'reasoning':
85
+ effect[child.tag] = child.text.strip() if child.text else None
86
+ if effect:
87
+ result[tag].append(effect)
88
+
89
+ # Convert to nested DataFrame
90
+ df_data = {}
91
+ for key, value in result.items():
92
+ if isinstance(value, dict):
93
+ for sub_key, sub_value in value.items():
94
+ df_data[(key, '1', sub_key)] = [sub_value]
95
+ elif isinstance(value, list):
96
+ for i, item in enumerate(value):
97
+ for sub_key, sub_value in item.items():
98
+ df_data[(key, f"{i+1}", sub_key)] = [sub_value]
99
+ else:
100
+ df_data[(key, '1', '')] = [value]
101
+
102
+ # Create multi-index DataFrame
103
+ df = pd.DataFrame(df_data)
104
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
105
+
106
+ return df
107
  except ET.ParseError as e:
108
+ print(f"XML parsing error: {str(e)}")
109
+ print(f"Problematic XML content: {xml_string[:500]}...") # Print first 500 chars of XML
110
+ return pd.DataFrame()
111
+ except Exception as e:
112
+ print(f"Error in parse_xml_response: {str(e)}")
113
+ print(f"Traceback: {traceback.format_exc()}")
114
  return pd.DataFrame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def get_response(prompt, file_id, assistant_id):
117
+ try:
118
+ thread = client.beta.threads.create(
119
+ messages=[
120
+ {
121
+ "role": "user",
122
+ "content": prompts.info_prompt,
123
+ "attachments": [
124
+ {"file_id": file_id, "tools": [{"type": "file_search"}]}
125
+ ],
126
+ }
127
+ ]
128
+ )
129
+ run = client.beta.threads.runs.create_and_poll(
130
+ thread_id=thread.id, assistant_id=assistant_id
131
+ )
132
+ messages = list(
133
+ client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
134
+ )
135
+ assert len(messages) == 1, f"Expected 1 message, got {len(messages)}"
136
+ message_content = messages[0].content[0].text
137
+ annotations = message_content.annotations
138
+ for index, annotation in enumerate(annotations):
139
+ message_content.value = message_content.value.replace(annotation.text, f"")
140
+ return message_content.value
141
+ except Exception as e:
142
+ print(f"Error in get_response: {str(e)}")
143
+ print(f"Traceback: {traceback.format_exc()}")
144
+ raise
145
 
146
  def process(file_content):
147
+ try:
148
+ if not os.path.exists("cache"):
149
+ os.makedirs("cache")
150
+ file_name = f"cache/{time.time()}.pdf"
151
+ with open(file_name, "wb") as f:
152
+ f.write(file_content)
153
+
154
+ message_file = client.files.create(file=open(file_name, "rb"), purpose="assistants")
155
+
156
+ response = get_response(prompts.info_prompt, message_file.id, demo.id)
157
+ df = parse_xml_response(response)
158
+
159
+ if df.empty:
160
+ return "<p>No valid information could be extracted from the provided file.</p>"
161
+
162
+ # Transpose the DataFrame
163
+ df_transposed = df.T.reset_index()
164
+ df_transposed.columns = ['Category', 'Index', 'Field', 'Value']
165
+ df_transposed = df_transposed.sort_values(['Category', 'Index', 'Field'])
166
+
167
+ # Convert to HTML with some basic styling
168
+ html = df_transposed.to_html(index=False, classes='table table-striped table-bordered', escape=False)
169
+
170
+ # Add some custom CSS for better readability
171
+ html = f"""
172
+ <style>
173
+ .table {{
174
+ width: 100%;
175
+ max-width: 100%;
176
+ margin-bottom: 1rem;
177
+ background-color: transparent;
178
+ }}
179
+ .table td, .table th {{
180
+ padding: .75rem;
181
+ vertical-align: top;
182
+ border-top: 1px solid #dee2e6;
183
+ }}
184
+ .table thead th {{
185
+ vertical-align: bottom;
186
+ border-bottom: 2px solid #dee2e6;
187
+ }}
188
+ .table tbody + tbody {{
189
+ border-top: 2px solid #dee2e6;
190
+ }}
191
+ .table-striped tbody tr:nth-of-type(odd) {{
192
+ background-color: rgba(0,0,0,.05);
193
+ }}
194
+ </style>
195
+ {html}
196
+ """
197
+
198
+ return html
199
+ except Exception as e:
200
+ error_message = f"An error occurred while processing the file: {str(e)}"
201
+ print(error_message)
202
+ print(f"Traceback: {traceback.format_exc()}")
203
+ return f"<p>{error_message}</p>"
204
 
205
  def gradio_interface():
206
  upload_component = gr.File(label="Upload PDF", type="binary")
 
217
  demo.launch()
218
 
219
  if __name__ == "__main__":
220
+ try:
221
+ gradio_interface()
222
+ except Exception as e:
223
+ print(f"Error launching Gradio interface: {str(e)}")
224
+ print(f"Traceback: {traceback.format_exc()}")
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- gradio==3.50.2
2
  openai==1.51.2
3
  pandas
 
1
+ gradio==4.29.0
2
  openai==1.51.2
3
  pandas