Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -13,6 +13,7 @@ import gradio as gr 
     | 
|
| 13 | 
         
             
            import torch
         
     | 
| 14 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 15 | 
         
             
            from fpdf import FPDF
         
     | 
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
            # === Configuration ===
         
     | 
| 18 | 
         
             
            persistent_dir = "/data/hf_cache"
         
     | 
| 
         @@ -223,17 +224,13 @@ Avoid repeating the same points multiple times. 
     | 
|
| 223 | 
         
             
                final_response = remove_duplicate_paragraphs(final_response)
         
     | 
| 224 | 
         
             
                return final_response
         
     | 
| 225 | 
         | 
| 226 | 
         
            -
            def  
     | 
| 227 | 
         
            -
                 
     | 
| 228 | 
         
            -
             
     | 
| 229 | 
         
            -
                def clean_for_pdf(text):
         
     | 
| 230 | 
         
            -
                    # Remove emojis and any non-latin characters
         
     | 
| 231 | 
         
            -
                    return ''.join(c for c in text if unicodedata.category(c)[0] != 'So')
         
     | 
| 232 | 
         | 
| 
         | 
|
| 233 | 
         
             
                chart_dir = os.path.join(os.path.dirname(report_path), "charts")
         
     | 
| 234 | 
         
             
                os.makedirs(chart_dir, exist_ok=True)
         
     | 
| 235 | 
         | 
| 236 | 
         
            -
                # Dummy chart
         
     | 
| 237 | 
         
             
                chart_path = os.path.join(chart_dir, "summary_chart.png")
         
     | 
| 238 | 
         
             
                categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
         
     | 
| 239 | 
         
             
                values = [4, 2, 3, 1, 5]
         
     | 
| 
         @@ -244,7 +241,6 @@ def generate_pdf_report_with_charts(summary: str, report_path: str): 
     | 
|
| 244 | 
         
             
                plt.savefig(chart_path)
         
     | 
| 245 | 
         
             
                plt.close()
         
     | 
| 246 | 
         | 
| 247 | 
         
            -
                # PDF report
         
     | 
| 248 | 
         
             
                pdf_path = report_path.replace('.md', '.pdf')
         
     | 
| 249 | 
         
             
                pdf = FPDF()
         
     | 
| 250 | 
         
             
                pdf.add_page()
         
     | 
| 
         @@ -253,7 +249,8 @@ def generate_pdf_report_with_charts(summary: str, report_path: str): 
     | 
|
| 253 | 
         
             
                pdf.ln(5)
         
     | 
| 254 | 
         | 
| 255 | 
         
             
                for line in summary.split("\n"):
         
     | 
| 256 | 
         
            -
                     
     | 
| 
         | 
|
| 257 | 
         | 
| 258 | 
         
             
                pdf.ln(10)
         
     | 
| 259 | 
         
             
                pdf.image(chart_path, w=150)
         
     | 
| 
         @@ -264,31 +261,43 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di 
     | 
|
| 264 | 
         
             
                if not file or not hasattr(file, "name"):
         
     | 
| 265 | 
         
             
                    messages.append({"role": "assistant", "content": "β Please upload a valid file."})
         
     | 
| 266 | 
         
             
                    return messages, None
         
     | 
| 
         | 
|
| 267 | 
         
             
                start_time = time.time()
         
     | 
| 268 | 
         
             
                messages.append({"role": "user", "content": f"π Processing file: {os.path.basename(file.name)}"})
         
     | 
| 
         | 
|
| 269 | 
         
             
                try:
         
     | 
| 270 | 
         
             
                    extracted = extract_text(file.name)
         
     | 
| 271 | 
         
             
                    if not extracted:
         
     | 
| 272 | 
         
             
                        messages.append({"role": "assistant", "content": "β Could not extract text."})
         
     | 
| 273 | 
         
             
                        return messages, None
         
     | 
| 
         | 
|
| 274 | 
         
             
                    chunks = split_text(extracted)
         
     | 
| 275 | 
         
             
                    batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
         
     | 
| 276 | 
         
             
                    messages.append({"role": "assistant", "content": f"π Split into {len(batches)} batches. Analyzing..."})
         
     | 
| 
         | 
|
| 277 | 
         
             
                    batch_results = analyze_batches(agent, batches)
         
     | 
| 278 | 
         
             
                    valid = [res for res in batch_results if not res.startswith("β")]
         
     | 
| 
         | 
|
| 279 | 
         
             
                    if not valid:
         
     | 
| 280 | 
         
             
                        messages.append({"role": "assistant", "content": "β No valid batch outputs."})
         
     | 
| 281 | 
         
             
                        return messages, None
         
     | 
| 
         | 
|
| 282 | 
         
             
                    summary = generate_final_summary(agent, "\n\n".join(valid))
         
     | 
| 
         | 
|
| 283 | 
         
             
                    report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
         
     | 
| 284 | 
         
             
                    with open(report_path, 'w', encoding='utf-8') as f:
         
     | 
| 285 | 
         
             
                        f.write(f"# Final Medical Report\n\n{summary}")
         
     | 
| 
         | 
|
| 286 | 
         
             
                    pdf_path = generate_pdf_report_with_charts(summary, report_path)
         
     | 
| 
         | 
|
| 287 | 
         
             
                    end_time = time.time()
         
     | 
| 288 | 
         
             
                    elapsed_time = end_time - start_time
         
     | 
| 
         | 
|
| 289 | 
         
             
                    messages.append({"role": "assistant", "content": f"π **Final Report:**\n\n{summary}"})
         
     | 
| 290 | 
         
             
                    messages.append({"role": "assistant", "content": f"β
 Report generated in **{elapsed_time:.2f} seconds**.\n\nπ₯ PDF report ready: {os.path.basename(pdf_path)}"})
         
     | 
| 
         | 
|
| 291 | 
         
             
                    return messages, pdf_path
         
     | 
| 
         | 
|
| 292 | 
         
             
                except Exception as e:
         
     | 
| 293 | 
         
             
                    messages.append({"role": "assistant", "content": f"β Error: {str(e)}"})
         
     | 
| 294 | 
         
             
                    return messages, None
         
     | 
| 
         @@ -302,22 +311,27 @@ def create_ui(agent): 
     | 
|
| 302 | 
         
             
                    .gr-file, .gr-button { width: 100% !important; max-width: 400px; }
         
     | 
| 303 | 
         
             
                """) as demo:
         
     | 
| 304 | 
         
             
                    gr.Markdown("""
         
     | 
| 305 | 
         
            -
                    <h2 style= 
     | 
| 306 | 
         
            -
                    <p style= 
     | 
| 307 | 
         
             
                    """)
         
     | 
| 
         | 
|
| 308 | 
         
             
                    with gr.Column():
         
     | 
| 309 | 
         
             
                        chatbot = gr.Chatbot(label="π§  CPS Assistant", height=480, type="messages")
         
     | 
| 310 | 
         
             
                        upload = gr.File(label="π Upload Medical File", file_types=[".xlsx", ".csv", ".pdf"])
         
     | 
| 311 | 
         
             
                        analyze = gr.Button("π§  Analyze")
         
     | 
| 312 | 
         
             
                        download = gr.File(label="π₯ Download Report", visible=False, interactive=False)
         
     | 
| 
         | 
|
| 313 | 
         
             
                    state = gr.State(value=[])
         
     | 
| 
         | 
|
| 314 | 
         
             
                    def handle_analysis(file, chat):
         
     | 
| 315 | 
         
             
                        messages, report_path = process_report(agent, file, chat)
         
     | 
| 316 | 
         
             
                        return messages, gr.update(visible=bool(report_path), value=report_path), messages
         
     | 
| 
         | 
|
| 317 | 
         
             
                    analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
         
     | 
| 
         | 
|
| 318 | 
         
             
                return demo
         
     | 
| 319 | 
         | 
| 320 | 
         
             
            if __name__ == "__main__":
         
     | 
| 321 | 
         
             
                agent = init_agent()
         
     | 
| 322 | 
         
             
                ui = create_ui(agent)
         
     | 
| 323 | 
         
            -
                ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
         
     | 
| 
         | 
|
| 13 | 
         
             
            import torch
         
     | 
| 14 | 
         
             
            import matplotlib.pyplot as plt
         
     | 
| 15 | 
         
             
            from fpdf import FPDF
         
     | 
| 16 | 
         
            +
            import unicodedata
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            # === Configuration ===
         
     | 
| 19 | 
         
             
            persistent_dir = "/data/hf_cache"
         
     | 
| 
         | 
|
| 224 | 
         
             
                final_response = remove_duplicate_paragraphs(final_response)
         
     | 
| 225 | 
         
             
                return final_response
         
     | 
| 226 | 
         | 
| 227 | 
         
            +
            def remove_non_ascii(text):
         
     | 
| 228 | 
         
            +
                return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 229 | 
         | 
| 230 | 
         
            +
            def generate_pdf_report_with_charts(summary: str, report_path: str):
         
     | 
| 231 | 
         
             
                chart_dir = os.path.join(os.path.dirname(report_path), "charts")
         
     | 
| 232 | 
         
             
                os.makedirs(chart_dir, exist_ok=True)
         
     | 
| 233 | 
         | 
| 
         | 
|
| 234 | 
         
             
                chart_path = os.path.join(chart_dir, "summary_chart.png")
         
     | 
| 235 | 
         
             
                categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
         
     | 
| 236 | 
         
             
                values = [4, 2, 3, 1, 5]
         
     | 
| 
         | 
|
| 241 | 
         
             
                plt.savefig(chart_path)
         
     | 
| 242 | 
         
             
                plt.close()
         
     | 
| 243 | 
         | 
| 
         | 
|
| 244 | 
         
             
                pdf_path = report_path.replace('.md', '.pdf')
         
     | 
| 245 | 
         
             
                pdf = FPDF()
         
     | 
| 246 | 
         
             
                pdf.add_page()
         
     | 
| 
         | 
|
| 249 | 
         
             
                pdf.ln(5)
         
     | 
| 250 | 
         | 
| 251 | 
         
             
                for line in summary.split("\n"):
         
     | 
| 252 | 
         
            +
                    clean_line = remove_non_ascii(line)
         
     | 
| 253 | 
         
            +
                    pdf.multi_cell(0, 10, txt=clean_line)
         
     | 
| 254 | 
         | 
| 255 | 
         
             
                pdf.ln(10)
         
     | 
| 256 | 
         
             
                pdf.image(chart_path, w=150)
         
     | 
| 
         | 
|
| 261 | 
         
             
                if not file or not hasattr(file, "name"):
         
     | 
| 262 | 
         
             
                    messages.append({"role": "assistant", "content": "β Please upload a valid file."})
         
     | 
| 263 | 
         
             
                    return messages, None
         
     | 
| 264 | 
         
            +
             
     | 
| 265 | 
         
             
                start_time = time.time()
         
     | 
| 266 | 
         
             
                messages.append({"role": "user", "content": f"π Processing file: {os.path.basename(file.name)}"})
         
     | 
| 267 | 
         
            +
             
     | 
| 268 | 
         
             
                try:
         
     | 
| 269 | 
         
             
                    extracted = extract_text(file.name)
         
     | 
| 270 | 
         
             
                    if not extracted:
         
     | 
| 271 | 
         
             
                        messages.append({"role": "assistant", "content": "β Could not extract text."})
         
     | 
| 272 | 
         
             
                        return messages, None
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
             
                    chunks = split_text(extracted)
         
     | 
| 275 | 
         
             
                    batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
         
     | 
| 276 | 
         
             
                    messages.append({"role": "assistant", "content": f"π Split into {len(batches)} batches. Analyzing..."})
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
             
                    batch_results = analyze_batches(agent, batches)
         
     | 
| 279 | 
         
             
                    valid = [res for res in batch_results if not res.startswith("β")]
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
             
                    if not valid:
         
     | 
| 282 | 
         
             
                        messages.append({"role": "assistant", "content": "β No valid batch outputs."})
         
     | 
| 283 | 
         
             
                        return messages, None
         
     | 
| 284 | 
         
            +
             
     | 
| 285 | 
         
             
                    summary = generate_final_summary(agent, "\n\n".join(valid))
         
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
             
                    report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
         
     | 
| 288 | 
         
             
                    with open(report_path, 'w', encoding='utf-8') as f:
         
     | 
| 289 | 
         
             
                        f.write(f"# Final Medical Report\n\n{summary}")
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
             
                    pdf_path = generate_pdf_report_with_charts(summary, report_path)
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
             
                    end_time = time.time()
         
     | 
| 294 | 
         
             
                    elapsed_time = end_time - start_time
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
             
                    messages.append({"role": "assistant", "content": f"π **Final Report:**\n\n{summary}"})
         
     | 
| 297 | 
         
             
                    messages.append({"role": "assistant", "content": f"β
 Report generated in **{elapsed_time:.2f} seconds**.\n\nπ₯ PDF report ready: {os.path.basename(pdf_path)}"})
         
     | 
| 298 | 
         
            +
             
     | 
| 299 | 
         
             
                    return messages, pdf_path
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
             
                except Exception as e:
         
     | 
| 302 | 
         
             
                    messages.append({"role": "assistant", "content": f"β Error: {str(e)}"})
         
     | 
| 303 | 
         
             
                    return messages, None
         
     | 
| 
         | 
|
| 311 | 
         
             
                    .gr-file, .gr-button { width: 100% !important; max-width: 400px; }
         
     | 
| 312 | 
         
             
                """) as demo:
         
     | 
| 313 | 
         
             
                    gr.Markdown("""
         
     | 
| 314 | 
         
            +
                    <h2 style='text-align:center;'>π CPS: Clinical Patient Support System</h2>
         
     | 
| 315 | 
         
            +
                    <p style='text-align:center;'>Analyze and summarize unstructured medical files using AI (optimized for A100 GPU).</p>
         
     | 
| 316 | 
         
             
                    """)
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
             
                    with gr.Column():
         
     | 
| 319 | 
         
             
                        chatbot = gr.Chatbot(label="π§  CPS Assistant", height=480, type="messages")
         
     | 
| 320 | 
         
             
                        upload = gr.File(label="π Upload Medical File", file_types=[".xlsx", ".csv", ".pdf"])
         
     | 
| 321 | 
         
             
                        analyze = gr.Button("π§  Analyze")
         
     | 
| 322 | 
         
             
                        download = gr.File(label="π₯ Download Report", visible=False, interactive=False)
         
     | 
| 323 | 
         
            +
             
     | 
| 324 | 
         
             
                    state = gr.State(value=[])
         
     | 
| 325 | 
         
            +
             
     | 
| 326 | 
         
             
                    def handle_analysis(file, chat):
         
     | 
| 327 | 
         
             
                        messages, report_path = process_report(agent, file, chat)
         
     | 
| 328 | 
         
             
                        return messages, gr.update(visible=bool(report_path), value=report_path), messages
         
     | 
| 329 | 
         
            +
             
     | 
| 330 | 
         
             
                    analyze.click(fn=handle_analysis, inputs=[upload, state], outputs=[chatbot, download, state])
         
     | 
| 331 | 
         
            +
             
     | 
| 332 | 
         
             
                return demo
         
     | 
| 333 | 
         | 
| 334 | 
         
             
            if __name__ == "__main__":
         
     | 
| 335 | 
         
             
                agent = init_agent()
         
     | 
| 336 | 
         
             
                ui = create_ui(agent)
         
     | 
| 337 | 
         
            +
                ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
         
     |