CDGPT2-Deployment / utils.py
import io
import re

import numpy as np
import pydicom
import torch
from fastapi import HTTPException, UploadFile
from PIL import Image
from skimage.transform import resize


def clean_paragraph(text):
    # Remove leading dashes, bullets, list numbering, and extra whitespace from each line
    lines = text.strip().splitlines()
    cleaned_lines = [re.sub(r"^\s*[-•\d.]*\s*", "", line) for line in lines if line.strip()]
    return " ".join(cleaned_lines)


def split_report_sections(report_text):
    # Strip any markdown-style asterisks anywhere in the report
    report_text = report_text.replace("*", "")

    # Use regex to extract the findings and impression sections
    findings_match = re.search(r"(?i)findings:\s*(.*?)(?=(impressions?:))", report_text, re.DOTALL)
    impression_match = re.search(r"(?i)impressions?:\s*(.*)", report_text, re.DOTALL)

    findings_raw = findings_match.group(1).strip() if findings_match else ""
    impression = impression_match.group(1).strip() if impression_match else ""

    findings = clean_paragraph(findings_raw)
    impression = clean_paragraph(impression)  # impression is cleaned the same way as findings

    return {
        "findings": findings,
        "impression": impression,
    }
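
# Illustrative example of the two helpers above (the report string is made up,
# not actual model output):
#
#   report = "Findings:\n- Lungs are clear.\n- Heart size is normal.\nImpression:\n1. No acute disease."
#   split_report_sections(report)
#   # -> {"findings": "Lungs are clear. Heart size is normal.",
#   #     "impression": "No acute disease."}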


def load_image(image):
    image = image.convert("RGB")
    image_array = np.asarray(image) / 255.0  # Normalize to [0, 1]
    image_array = resize(image_array, (224, 224))  # Resize to the model's expected input size
    image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)  # CxHxW
    # Standard ImageNet normalization constants
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    image_tensor = (image_tensor - mean[:, None, None]) / std[:, None, None]
    return image_tensor.unsqueeze(0)  # Add batch dimension
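
# Quick shape check (sketch only; any PIL image convertible to RGB behaves the same):
#
#   dummy = Image.new("L", (512, 512))   # synthetic greyscale test image
#   load_image(dummy).shape              # -> torch.Size([1, 3, 224, 224])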


async def convert_to_png(file: UploadFile) -> Image.Image:
    """Convert an uploaded JPG, PNG, or DICOM file into a PIL Image."""
    image_data = await file.read()

    if file.content_type in ["image/jpeg", "image/png", "image/jpg"]:
        image = Image.open(io.BytesIO(image_data))
        return image

    if file.content_type == "application/dicom" or file.filename.endswith((".dcm", ".dicom")):
        dicom_data = pydicom.dcmread(io.BytesIO(image_data))
        pixel_array = dicom_data.pixel_array
        if pixel_array.dtype != np.uint8:
            # Scale 10/12/16-bit pixel data down to 0-255
            pixel_array = (pixel_array / pixel_array.max() * 255).astype(np.uint8)
        image = Image.fromarray(pixel_array).convert("RGB")
        return image

    raise HTTPException(status_code=400, detail="Unsupported media type")
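
# How these utilities are typically wired into a route (sketch only; the actual
# route path lives in the FastAPI app, and "app" / "model" here are hypothetical
# placeholders, not names defined in this module):
#
#   @app.post("/report")
#   async def generate_report(file: UploadFile = File(...)):
#       image = await convert_to_png(file)     # PIL image from JPG/PNG/DICOM
#       tensor = load_image(image)             # (1, 3, 224, 224) normalized tensor
#       raw_report = model.generate(tensor)    # hypothetical report-generation model
#       return split_report_sections(raw_report)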


def dicom_to_png(ds):
    pixel_array = ds.pixel_array

    # Normalize if needed (handle 16-bit images)
    if pixel_array.dtype != np.uint8:
        # Scale to 0-255
        pixel_array = pixel_array.astype(np.float32)
        pixel_array -= pixel_array.min()
        pixel_array /= pixel_array.max()
        pixel_array *= 255.0
        pixel_array = pixel_array.astype(np.uint8)

    # Convert to PIL Image
    img = Image.fromarray(pixel_array).convert("RGB")
    return img
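

# Minimal smoke test for the pure-Python helpers (a sketch assuming no model or
# DICOM file is available; run with `python utils.py`):
if __name__ == "__main__":
    sample = "Findings:\n- Lungs are clear.\nImpression:\n1. No acute disease."
    print(split_report_sections(sample))

    dummy = Image.new("RGB", (512, 512))
    print(load_image(dummy).shape)  # expected: torch.Size([1, 3, 224, 224])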