Spaces:
Sleeping
Sleeping
| import glob | |
| import json | |
| import os | |
| import time | |
| import gradio as gr | |
| from openai import OpenAI | |
| import prompts | |
| import traceback | |
| from io import StringIO | |
| import pandas as pd | |
| from typing import Dict, Any | |
| from typing import List, Optional | |
| from pydantic import BaseModel, Field | |
| from structures import ClinicalInfo | |
# Shared OpenAI client; the API key is taken from the OPENAI_API_KEY
# environment variable (None is passed through if it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Model used both for the assistant and for the structured-output parse call.
model_name = "gpt-4o-2024-08-06"

# The assistant is created once, at import time, and reused for every request.
# A failure here is reported and re-raised so the module does not load half-configured.
try:
    demo = client.beta.assistants.create(
        model=model_name,
        name="Information Extractor",
        instructions="Extract information from this note and return it as a JSON object.",
        tools=[{"type": "file_search"}],
    )
except Exception as exc:
    print(f"Error creating assistant: {exc}")
    raise
def parse_response(prompt):
    """Parse free-form model output into a ``ClinicalInfo`` dictionary.

    Sends ``prompt`` as a single user message to the chat-completions parse
    helper with ``ClinicalInfo`` as the response format and returns the parsed
    model as a plain dict (``model_dump()``).

    Args:
        prompt: Text to send as the user message.

    Returns:
        The parsed ``ClinicalInfo`` instance converted with ``model_dump()``.

    Raises:
        ValueError: If the first choice carries no parsed object (for example
            because the model refused), instead of failing later with an
            opaque ``AttributeError`` on ``None``.
    """
    chat_completion = client.beta.chat.completions.parse(
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        response_format=ClinicalInfo,
    )
    message = chat_completion.choices[0].message
    if message.parsed is None:
        # `refusal` is populated when the model declines to answer; fall back
        # to a generic reason when the attribute is absent or empty.
        reason = getattr(message, "refusal", None) or "no structured output was returned"
        raise ValueError(
            f"Model response could not be parsed into ClinicalInfo: {reason}"
        )
    return message.parsed.model_dump()
def get_response(file_id, assistant_id, max_retries=3):
    """Run the assistant against an uploaded file and return its text reply.

    For each attempt a new thread is created containing ``prompts.info_prompt``
    with the file attached for ``file_search``, a run is started, and the run
    is polled once per second until it completes. Annotation markers are
    stripped from the reply text before it is returned.

    Args:
        file_id: ID of a file previously uploaded to the OpenAI API.
        assistant_id: ID of the assistant that should process the file.
        max_retries: Total number of attempts; must be at least 1.

    Returns:
        The assistant's reply text with annotation markers removed.

    Raises:
        ValueError: If ``max_retries`` is less than 1 (no attempt would be
            made and the function would otherwise return ``None``).
        RuntimeError: If every attempt fails; the last underlying error is
            attached as ``__cause__``.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    # Run states that can never transition to "completed"; polling on them
    # would otherwise spin forever.
    failed_states = {"failed", "cancelled", "expired", "incomplete"}

    for attempt in range(max_retries):
        try:
            thread = client.beta.threads.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompts.info_prompt,
                        "attachments": [
                            {"file_id": file_id, "tools": [{"type": "file_search"}]}
                        ],
                    }
                ]
            )
            run = client.beta.threads.runs.create(
                thread_id=thread.id,
                assistant_id=assistant_id,
                instructions="Please provide your response as a valid JSON object.",
            )
            run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
            while run.status != "completed":
                if run.status in failed_states:
                    raise RuntimeError(f"Run ended with status '{run.status}'")
                time.sleep(1)
                run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

            messages = list(
                client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
            )
            # Explicit check rather than `assert`, which is stripped under -O.
            if len(messages) != 1:
                raise RuntimeError(f"Expected 1 message, got {len(messages)}")

            message_content = messages[0].content[0].text
            text = message_content.value
            # Drop the inline citation markers that file_search adds to the text.
            for annotation in message_content.annotations:
                text = text.replace(annotation.text, "")
            return text
        except Exception as e:
            print(f"Error in get_response (attempt {attempt + 1}): {str(e)}")
            print(f"Traceback: {traceback.format_exc()}")
            if attempt < max_retries - 1:
                print("Retrying in 5 seconds...")
                time.sleep(5)
            else:
                raise RuntimeError(
                    "Max retries reached. Unable to get response from the model."
                ) from e
def clinical_info_to_dataframe(clinical_info: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a ClinicalInfo dictionary into a Category/Field/Value table.

    Each top-level entry is flattened as follows:

    * dict value  -> one row per key, with ``Category`` set to the entry name;
    * list value  -> items are numbered from 1 (``Category`` is
      ``"<name>_<n>"``); dict items give one row per key, any other item
      gives a single row whose ``Field`` is ``'value'``;
    * anything else (``None`` or a plain scalar) -> a single row whose
      ``Field`` is ``'value'``.

    Every value is stored as ``str(value)``, so ``None`` becomes ``'None'``.

    Args:
        clinical_info: Mapping of field name to nested dict, list, scalar or None.

    Returns:
        A DataFrame with columns ``Category``, ``Field`` and ``Value``. The
        columns are present even when there are no rows.
    """
    rows = []

    def add_row(category, field_name, value):
        rows.append({
            'Category': category,
            'Field': field_name,
            'Value': str(value),
        })

    for field, value in clinical_info.items():
        if isinstance(value, dict):
            for sub_field, sub_value in value.items():
                add_row(field, sub_field, sub_value)
        elif isinstance(value, list):
            for position, item in enumerate(value, start=1):
                category = f"{field}_{position}"
                if isinstance(item, dict):
                    for sub_field, sub_value in item.items():
                        add_row(category, sub_field, sub_value)
                else:
                    # Lists of plain values (e.g. strings) have no sub-fields.
                    add_row(category, 'value', item)
        else:
            # None and top-level scalars are kept rather than silently dropped.
            add_row(field, 'value', value)

    # Fixed columns keep downstream column access valid on an empty result.
    return pd.DataFrame(rows, columns=['Category', 'Field', 'Value'])
def process(file_content):
    """Extract clinical information from PDF bytes and render it as HTML.

    The bytes are written to a timestamped file under ``cache/``, uploaded to
    the OpenAI API, run through ``get_response`` and ``parse_response``, and
    the resulting table is returned as a styled HTML fragment.

    Args:
        file_content: Raw bytes of the uploaded PDF.

    Returns:
        An HTML string. On success it is a ``<style>`` block followed by a
        table; otherwise it is a single ``<p>...</p>`` message (no input,
        nothing extracted, or an error description). Errors are reported in
        the returned HTML rather than raised.
    """
    # Imported locally under its own name because `html` is a common local name.
    from html import escape

    if not file_content:
        return "<p>No file content was provided.</p>"

    try:
        os.makedirs("cache", exist_ok=True)
        file_name = f"cache/{time.time()}.pdf"
        with open(file_name, "wb") as f:
            f.write(file_content)

        # Context manager so the upload handle is closed after the request.
        with open(file_name, "rb") as pdf:
            message_file = client.files.create(file=pdf, purpose="assistants")

        response = get_response(message_file.id, demo.id)  # includes retry logic
        response_prompt = f"Please parse the following response into the correct format: {response}"
        clinical_info = parse_response(response_prompt)

        df = clinical_info_to_dataframe(clinical_info)
        if df.empty:
            return "<p>No valid information could be extracted from the provided file.</p>"

        df = df.sort_values(['Category', 'Field'])

        # Cell values originate from the uploaded document via the model, so
        # they are escaped rather than injected as raw markup.
        table_html = df.to_html(
            index=False,
            classes='table table-striped table-bordered',
            escape=True,
        )

        return f"""
        <style>
        .table {{
            width: 100%;
            max-width: 100%;
            margin-bottom: 1rem;
            background-color: transparent;
        }}
        .table td, .table th {{
            padding: .75rem;
            vertical-align: top;
            border-top: 1px solid #dee2e6;
        }}
        .table thead th {{
            vertical-align: bottom;
            border-bottom: 2px solid #dee2e6;
        }}
        .table tbody + tbody {{
            border-top: 2px solid #dee2e6;
        }}
        .table-striped tbody tr:nth-of-type(odd) {{
            background-color: rgba(0,0,0,.05);
        }}
        </style>
        {table_html}
        """
    except Exception as e:
        error_message = f"An error occurred while processing the file: {str(e)}"
        print(error_message)
        print(f"Traceback: {traceback.format_exc()}")
        # Exception text may contain '<' or '&'; escape it for the HTML output.
        return f"<p>{escape(error_message, quote=False)}</p>"
def gradio_interface():
    """Build the Gradio UI around ``process``, enable its queue, and launch it.

    The interface takes one binary file upload and shows the HTML string
    returned by ``process``.
    """
    interface = gr.Interface(
        fn=process,
        inputs=gr.File(label="Upload PDF", type="binary"),
        outputs=gr.HTML(label="Extracted Information"),
        title="Clinical Note Information Extractor",
        description="This tool extracts key information from clinical notes in PDF format.",
    )
    interface.queue()
    interface.launch()
def run_in_terminal(file_path="../clinicalnotes_raw/0b7wtxiunxwploe6tnnluh0l84qg.pdf"):
    """Run the extractor on one PDF from the command line.

    Reads the file, passes its bytes to ``process``, and then either prints
    the returned message (when the result is a ``<p>...</p>`` string) or saves
    the HTML to ``output_<timestamp>.html`` and prints each extracted row.

    Args:
        file_path: Path of the PDF to process. Defaults to the sample note the
            script was originally hard-coded to use.

    Returns:
        None. All results and errors are reported on stdout.
    """
    print("Clinical Note Information Extractor")
    print("This tool extracts key information from clinical notes in PDF format.")

    if not os.path.exists(file_path):
        print(f"Error: File not found at {file_path}")
        return

    try:
        with open(file_path, "rb") as file:
            file_content = file.read()

        result = process(file_content)

        if result.startswith("<p>"):
            # Message-only result: strip the surrounding <p>...</p> tags.
            print(result[3:-4])
        else:
            output_file = f"output_{time.time()}.html"
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(result)
            print(f"Extraction completed. Results saved to {output_file}")

            # read_html expects a path or file-like object; passing literal
            # HTML text directly is deprecated, so wrap it in StringIO.
            df = pd.read_html(StringIO(result))[0]
            print("\nExtracted Information:")
            for _, row in df.iterrows():
                print(f"{row['Category']} - {row['Field']}: {row['Value']}")
    except Exception as e:
        print(f"An error occurred while processing the file: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
if __name__ == "__main__":
    # Script entry point: serve the web UI. Call run_in_terminal() here
    # instead for a console-only run.
    try:
        gradio_interface()
    except Exception as exc:
        print(f"Error launching Gradio interface: {exc}")
        print(f"Traceback: {traceback.format_exc()}")