import time
import gradio as gr
import torch
import numpy as np
from PIL import Image
import time
import io
import subprocess
import sys
# Install required packages
def install_packages():
    packages = [
        "transformers",
        "accelerate", 
        "timm",
        "easyocr"
    ]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except:
            print(f"Warning: Could not install {package}")
# Install packages at startup
install_packages()
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig
# Global variables for model
processor = None
model = None
config = None
ocr_reader = None
def load_model():
    """Load the Gemma 3n model"""
    global processor, model, config, ocr_reader
    
    try:
        print("🚀 Loading Gemma 3n model...")
        GEMMA_PATH = "google/gemma-3n-e2b-it"
        
        # Load configuration
        config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Config loaded")
        
        # Load processor
        processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Processor loaded")
        
        # Load model
        model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_PATH,
            config=config,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True
        )
        print("✅ Model loaded successfully!")
        
        # Set up compilation fix
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True
        
        # Initialize OCR
        try:
            import easyocr
            ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False)
            print("✅ EasyOCR initialized")
        except Exception as e:
            print(f"⚠️ EasyOCR not available: {e}")
            ocr_reader = None
            
        return True
        
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return False
def generate_soap_note(text):
    """Generate SOAP note using Gemma 3n"""
    if model is None or processor is None:
        return "❌ Model not loaded. Please wait for initialization."
    
    soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.
Medical notes:
{text}
Please format as:
S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)
O - OBJECTIVE: (vital signs, physical examination findings)  
A - ASSESSMENT: (diagnosis/clinical impression)
P - PLAN: (treatment plan, follow-up instructions)
Generate a complete, professional SOAP note:"""
    
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}]
    }, {
        "role": "user", 
        "content": [{"type": "text", "text": soap_prompt}]
    }]
    
    try:
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)
        
        input_len = inputs["input_ids"].shape[-1]
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                pad_token_id=processor.tokenizer.eos_token_id,
                disable_compile=True
            )
        
        response = processor.batch_decode(
            outputs[:, input_len:],
            skip_special_tokens=True
        )[0].strip()
        
        return response
        
    except Exception as e:
        return f"❌ SOAP generation failed: {str(e)}"
def extract_text_from_image(image):
    """Extract text using EasyOCR - fast processing"""
    if ocr_reader is None:
        return "❌ OCR not available"
    
    try:
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        img_array = np.array(image)
        
        results = ocr_reader.readtext(img_array, detail=0, paragraph=True)
        if results:
            return ' '.join(results).strip()
        else:
            return "❌ No text detected in image"
            
    except Exception as e:
        return f"❌ OCR failed: {str(e)}"
def process_medical_input(image, text):
    """Main processing function for the Gradio interface"""
    
    if image is not None and text.strip():
        return "⚠️ Please provide either an image OR text, not both.", ""
    
    if image is not None:
        # Process image
        print("🔍 Extracting text from image...")
        extracted_text = extract_text_from_image(image)
        
        if extracted_text.startswith('❌'):
            return extracted_text, ""
        
        print("🤖 Generating SOAP note...")
        soap_note = generate_soap_note(extracted_text)
        
        return extracted_text, soap_note
        
    elif text.strip():
        # Process text directly
        print("🤖 Generating SOAP note from text...")
        soap_note = generate_soap_note(text.strip())
        return text.strip(), soap_note
        
    else:
        return "❌ Please provide either an image or text input.", ""
def create_demo():
    """Create the Gradio demo interface"""
    
    # Sample text for demonstration
    sample_text = """Patient: John Smith, 45yo male
CC: Chest pain
Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
PMH: HTN, DM Type 2
Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
PE: Diaphoretic, anxious appearance
EKG: ST elevation in leads II, III, aVF"""
    
    with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
        
        gr.HTML("""
        
🏥 Medical OCR SOAP Generator - LIVE DEMO
        🎯 For Competition Judges - Quick 2-Minute Demo:
        
        📋 SAMPLE IMAGE PROVIDED:
        👆 Download "docs-note-to-upload.jpg" from the Files tab above, then upload it below
        OR click "Try Sample Medical Text" button for instant text demo
         
        Demo Steps:
        
        - Upload the sample image (docs-note-to-upload.jpg from Files tab) OR click sample text button
- Click "Generate SOAP Note"
- Wait ~2 minutes for AI processing (first time only)
- See professional SOAP note generated by Gemma 3n
✅ What This Demo Shows:
        
        - Real OCR extraction from handwritten medical notes
- AI-powered medical reasoning with Gemma 3n
- Professional SOAP formatting (Subjective, Objective, Assessment, Plan)
- HIPAA-compliant local processing
⚠️ Note: First generation takes ~2 minutes as model loads. Subsequent ones are faster.
        
        """)
        
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil", 
                    label="📷 Upload Medical Image",
                    height=300
                )
                
                text_input = gr.Textbox(
                    label="📝 Or Enter Medical Text",
                    placeholder=sample_text,
                    lines=8,
                    max_lines=15
                )
                
                submit_btn = gr.Button(
                    "Generate SOAP Note",
                    variant="primary",
                    size="lg"
                )
                
            with gr.Column():
                extracted_output = gr.Textbox(
                    label="📋 Extracted/Input Text",
                    lines=6,
                    max_lines=10
                )
                
                soap_output = gr.Textbox(
                    label="🏥 Generated SOAP Note",
                    lines=12,
                    max_lines=20
                )
        
        # Example section
        gr.Markdown("### 📋 Quick Test Example")
        example_btn = gr.Button("Try Sample Medical Text", variant="secondary")
        
        def load_example():
            return sample_text, None
            
        example_btn.click(
            load_example,
            outputs=[text_input, image_input]
        )
        
        # Process function
        submit_btn.click(
            process_medical_input,
            inputs=[image_input, text_input],
            outputs=[extracted_output, soap_output]
        )
        
        gr.Markdown("""
        ---
        **About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition. 
        All processing is done locally for HIPAA compliance.
        
        **Competition Entry:** Medical AI Innovation Challenge 2024
        """)
    
    return demo
# Initialize the application
if __name__ == "__main__":
    print("🚀 Starting Medical OCR SOAP Generator...")
    
    # Load model
    model_loaded = load_model()
    
    if model_loaded:
        print("✅ All systems ready!")
        demo = create_demo()
        demo.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860
        )
    else:
        print("❌ Failed to load model. Creating fallback demo...")
        
        def fallback_demo():
            return "❌ Model loading failed. Please check the logs.", "❌ Model not available."
        
        demo = gr.Interface(
            fn=fallback_demo,
            inputs=[
                gr.Image(type="pil", label="Upload Medical Image"),
                gr.Textbox(label="Enter Medical Text", lines=5)
            ],
            outputs=[
                gr.Textbox(label="Status"),
                gr.Textbox(label="Error Message")
            ],
            title="❌ Medical OCR - Model Loading Failed"
        )
        
        demo.launch(share=True)
import io
import subprocess
import sys
import cv2
# Install required packages
def install_packages():
    packages = [
        "transformers",
        "accelerate", 
        "timm",
        "easyocr",
        "opencv-python"
    ]
    for package in packages:
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        except:
            print(f"Warning: Could not install {package}")
# Install packages at startup
install_packages()
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoConfig
# Global variables for model
processor = None
model = None
config = None
ocr_reader = None
def load_model():
    """Load the Gemma 3n model"""
    global processor, model, config, ocr_reader
    
    try:
        print("🚀 Loading Gemma 3n model...")
        GEMMA_PATH = "google/gemma-3n-e2b-it"
        
        # Load configuration
        config = AutoConfig.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Config loaded")
        
        # Load processor
        processor = AutoProcessor.from_pretrained(GEMMA_PATH, trust_remote_code=True)
        print("✅ Processor loaded")
        
        # Load model
        model = AutoModelForImageTextToText.from_pretrained(
            GEMMA_PATH,
            config=config,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True
        )
        print("✅ Model loaded successfully!")
        
        # Set up compilation fix
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True
        
        # Initialize OCR
        try:
            import easyocr
            ocr_reader = easyocr.Reader(['en'], gpu=False, verbose=False)
            print("✅ EasyOCR initialized")
        except Exception as e:
            print(f"⚠️ EasyOCR not available: {e}")
            ocr_reader = None
            
        return True
        
    except Exception as e:
        print(f"❌ Model loading failed: {e}")
        return False
def generate_soap_note(text):
    """Generate SOAP note using Gemma 3n"""
    if model is None or processor is None:
        return "❌ Model not loaded. Please wait for initialization."
    
    soap_prompt = f"""You are a medical AI assistant. Convert the following medical notes into a properly formatted SOAP note.
Medical notes:
{text}
Please format as:
S - SUBJECTIVE: (chief complaint, history of present illness, past medical history, medications, allergies)
O - OBJECTIVE: (vital signs, physical examination findings)  
A - ASSESSMENT: (diagnosis/clinical impression)
P - PLAN: (treatment plan, follow-up instructions)
Generate a complete, professional SOAP note:"""
    
    messages = [{
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert medical AI assistant specialized in creating SOAP notes from medical documentation."}]
    }, {
        "role": "user", 
        "content": [{"type": "text", "text": soap_prompt}]
    }]
    
    try:
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)
        
        input_len = inputs["input_ids"].shape[-1]
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                pad_token_id=processor.tokenizer.eos_token_id,
                disable_compile=True
            )
        
        response = processor.batch_decode(
            outputs[:, input_len:],
            skip_special_tokens=True
        )[0].strip()
        
        return response
        
    except Exception as e:
        return f"❌ SOAP generation failed: {str(e)}"
def preprocess_image_for_ocr(image):
    """Preprocess image for better OCR results using CLAHE"""
    try:
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        img_array = np.array(image)
        
        # Convert to grayscale
        if len(img_array.shape) == 3:
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray = img_array
        
        # Resize if too small
        height, width = gray.shape
        if height < 300 or width < 300:
            scale = max(300/height, 300/width)
            new_height = int(height * scale)
            new_width = int(width * scale)
            gray = cv2.resize(gray, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
        
        # Enhance image with CLAHE
        gray = cv2.medianBlur(gray, 3)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        gray = clahe.apply(gray)
        _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        
        return gray
    except Exception as e:
        print(f"⚠️ Image preprocessing failed: {e}")
        # Fallback to original image if preprocessing fails
        return np.array(image)
def extract_text_from_image(image):
    """Extract text using EasyOCR with CLAHE preprocessing"""
    if ocr_reader is None:
        return "❌ OCR not available"
    
    try:
        # Apply CLAHE preprocessing for better OCR
        processed_img = preprocess_image_for_ocr(image)
        
        results = ocr_reader.readtext(processed_img, detail=0, paragraph=True)
        if results:
            return ' '.join(results).strip()
        else:
            return "❌ No text detected in image"
            
    except Exception as e:
        return f"❌ OCR failed: {str(e)}"
def process_medical_input(image, text):
    """Main processing function for the Gradio interface"""
    
    if image is not None and text.strip():
        return "⚠️ Please provide either an image OR text, not both.", ""
    
    if image is not None:
        # Process image
        print("🔍 Extracting text from image...")
        extracted_text = extract_text_from_image(image)
        
        if extracted_text.startswith('❌'):
            return extracted_text, ""
        
        print("🤖 Generating SOAP note...")
        soap_note = generate_soap_note(extracted_text)
        
        return extracted_text, soap_note
        
    elif text.strip():
        # Process text directly
        print("🤖 Generating SOAP note from text...")
        soap_note = generate_soap_note(text.strip())
        return text.strip(), soap_note
        
    else:
        return "❌ Please provide either an image or text input.", ""
def create_demo():
    """Create the Gradio demo interface"""
    
    # Sample text for demonstration
    sample_text = """Patient: John Smith, 45yo male
CC: Chest pain
Vitals: BP 140/90, HR 88, RR 16, O2 98%, Temp 98.6F
HPI: Patient reports crushing chest pain x 2 hours, radiating to left arm
PMH: HTN, DM Type 2
Current Meds: Lisinopril 10mg daily, Metformin 500mg BID
PE: Diaphoretic, anxious appearance
EKG: ST elevation in leads II, III, aVF"""
    
    with gr.Blocks(title="Medical OCR SOAP Generator", theme=gr.themes.Soft()) as demo:
        
        gr.HTML("""
        🏥 Medical OCR SOAP Generator - LIVE DEMO
        🎯 For Competition Judges - Quick 2-Minute Demo:
        
        📋 SAMPLE IMAGE PROVIDED:
        👆 Download "docs-note-to-upload.jpg" from the Files tab above, then upload it below
        OR click "Try Sample Medical Text" button for instant text demo
         
        Demo Steps:
        
        - Upload the sample image (docs-note-to-upload.jpg from Files tab) OR click sample text button
- Click "Generate SOAP Note"
- Wait ~60-90 seconds for AI processing (first time only)
- See professional SOAP note generated by Gemma 3n
✅ What This Demo Shows:
        
        - Real OCR extraction from handwritten medical notes
- AI-powered medical reasoning with Gemma 3n
- Professional SOAP formatting (Subjective, Objective, Assessment, Plan)
- HIPAA-compliant local processing
⚠️ Note: First generation takes ~60-90 seconds as model loads. Subsequent ones are faster.
        
        """)
        
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil", 
                    label="📷 Upload Medical Image",
                    height=300
                )
                
                text_input = gr.Textbox(
                    label="📝 Or Enter Medical Text",
                    placeholder=sample_text,
                    lines=8,
                    max_lines=15
                )
                
                submit_btn = gr.Button(
                    "Generate SOAP Note",
                    variant="primary",
                    size="lg"
                )
                
            with gr.Column():
                extracted_output = gr.Textbox(
                    label="📋 Extracted/Input Text",
                    lines=6,
                    max_lines=10
                )
                
                soap_output = gr.Textbox(
                    label="🏥 Generated SOAP Note",
                    lines=12,
                    max_lines=20
                )
        
        # Example section
        gr.Markdown("### 📋 Quick Test Example")
        example_btn = gr.Button("Try Sample Medical Text", variant="secondary")
        
        def load_example():
            return sample_text, None
            
        example_btn.click(
            load_example,
            outputs=[text_input, image_input]
        )
        
        # Process function
        submit_btn.click(
            process_medical_input,
            inputs=[image_input, text_input],
            outputs=[extracted_output, soap_output]
        )
        
        gr.Markdown("""
        ---
        **About:** This application uses Google's Gemma 3n model for medical text understanding and EasyOCR for handwriting recognition. 
        All processing is done locally for HIPAA compliance.
        
        **Competition Entry:** Medical AI Innovation Challenge 2024
        """)
    
    return demo
# Initialize the application
if __name__ == "__main__":
    print("🚀 Starting Medical OCR SOAP Generator...")
    
    # Load model
    model_loaded = load_model()
    
    if model_loaded:
        print("✅ All systems ready!")
        demo = create_demo()
        demo.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860
        )
    else:
        print("❌ Failed to load model. Creating fallback demo...")
        
        def fallback_demo():
            return "❌ Model loading failed. Please check the logs.", "❌ Model not available."
        
        demo = gr.Interface(
            fn=fallback_demo,
            inputs=[
                gr.Image(type="pil", label="Upload Medical Image"),
                gr.Textbox(label="Enter Medical Text", lines=5)
            ],
            outputs=[
                gr.Textbox(label="Status"),
                gr.Textbox(label="Error Message")
            ],
            title="❌ Medical OCR - Model Loading Failed"
        )
        
        demo.launch(share=True)