import io
import re

import numpy as np
import pydicom
import torch
from fastapi import HTTPException, UploadFile
from PIL import Image
from skimage.transform import resize

def clean_paragraph(text):
    """Strip leading bullets, dashes, list numbers, and surrounding whitespace
    from each line, then join the lines into a single paragraph."""
    lines = text.strip().splitlines()
    cleaned_lines = [re.sub(r"^\s*[-•\d.]*\s*", "", line) for line in lines if line.strip()]
    return " ".join(cleaned_lines)

def split_report_sections(report_text):
    """Extract the findings and impression sections of a report and return
    them as cleaned paragraphs."""
    # Strip Markdown-style asterisks anywhere in the report
    report_text = report_text.replace("*", "")
    # Findings run from "Findings:" up to "Impression(s):" (or the end of the
    # report if no impression section is present); impression runs to the end.
    findings_match = re.search(r"(?i)findings:\s*(.*?)(?=impressions?:|\Z)", report_text, re.DOTALL)
    impression_match = re.search(r"(?i)impressions?:\s*(.*)", report_text, re.DOTALL)

    findings_raw = findings_match.group(1).strip() if findings_match else ""
    impression_raw = impression_match.group(1).strip() if impression_match else ""

    # Both sections are normalized into single paragraphs
    findings = clean_paragraph(findings_raw)
    impression = clean_paragraph(impression_raw)

    return {
        "findings": findings,
        "impression": impression,
    }
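# Illustrative example (not part of the original module; the report text below is
# made up). Given a model-generated report such as:
#
#   report = "FINDINGS:\n- Lungs are clear.\n- No pleural effusion.\nIMPRESSION: Normal chest X-ray."
#   split_report_sections(report)
#
# the function would return roughly:
#
#   {"findings": "Lungs are clear. No pleural effusion.",
#    "impression": "Normal chest X-ray."}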

def load_image(image):
    """Preprocess a PIL image into a normalized (1, 3, 224, 224) float tensor."""
    image = image.convert("RGB")
    image_array = np.asarray(image) / 255.0  # Scale pixel values to [0, 1]
    image_array = resize(image_array, (224, 224))  # Resize to 224x224; channel axis is preserved
    image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)  # HxWxC -> CxHxW

    # Normalize with ImageNet channel statistics
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    image_tensor = (image_tensor - mean[:, None, None]) / std[:, None, None]

    return image_tensor.unsqueeze(0)  # Add batch dimension

async def convert_to_png(file: UploadFile) -> Image.Image:
    """Read an uploaded JPG, PNG, or DICOM file and return it as a PIL Image."""
    image_data = await file.read()
    if file.content_type in ["image/jpeg", "image/png", "image/jpg"]:
        image = Image.open(io.BytesIO(image_data))
        return image

    if file.content_type == "application/dicom" or file.filename.endswith(".dcm") or file.filename.endswith(".dicom"):
        dicom_data = pydicom.dcmread(io.BytesIO(image_data))
        pixel_array = dicom_data.pixel_array

        # Scale non-8-bit pixel data (e.g. 16-bit DICOM) down to 0-255
        if pixel_array.dtype != np.uint8:
            pixel_array = (pixel_array / pixel_array.max() * 255).astype(np.uint8)

        image = Image.fromarray(pixel_array).convert("RGB")
        return image

    raise HTTPException(status_code=400, detail="Unsupported media type")
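# Illustrative wiring (a sketch, not part of this module): in a FastAPI route,
# `convert_to_png` and `load_image` might be chained like this, where `app` and
# the "/predict" path are hypothetical names:
#
#   @app.post("/predict")
#   async def predict(file: UploadFile):
#       image = await convert_to_png(file)  # raises HTTPException(400) on unsupported types
#       tensor = load_image(image)          # (1, 3, 224, 224) tensor ready for the model
#       ...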

def dicom_to_png(ds):
    """Convert a pydicom dataset into an 8-bit RGB PIL Image."""
    pixel_array = ds.pixel_array

    # Normalize if needed (handle 16-bit images)
    if pixel_array.dtype != np.uint8:
        # Min-max scale to 0-255; guard against an all-constant image
        pixel_array = pixel_array.astype(np.float32)
        pixel_array -= pixel_array.min()
        max_val = pixel_array.max()
        if max_val > 0:
            pixel_array /= max_val
        pixel_array *= 255.0
        pixel_array = pixel_array.astype(np.uint8)

    # Convert to PIL Image
    img = Image.fromarray(pixel_array).convert("RGB")
    return img
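
if __name__ == "__main__":
    # Minimal smoke test added for illustration (not part of the original module):
    # run the preprocessing on a synthetic grayscale image and check the output shape.
    dummy = Image.fromarray(np.zeros((512, 512), dtype=np.uint8)).convert("RGB")
    batch = load_image(dummy)
    print("load_image output shape:", tuple(batch.shape))  # expected: (1, 3, 224, 224)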