| | import os |
| | import numpy as np |
| | import torch |
| | import SimpleITK as sitk |
| | import torch.nn.functional as F |
| | import cv2 |
| | from PIL import Image, ImageDraw, ImageOps |
| | import tempfile |
| | import gradio as gr |
| | from segment_anything.build_sam3D import sam_model_registry3D |
| | from utils.click_method import get_next_click3D_torch_ritm, get_next_click3D_torch_2 |
| |
|
| |
|
def build_model():
    """Load the BoSAM checkpoint and return an initialised SAM-3D model on GPU.

    Returns:
        The `vit_b_ori` SAM-3D model with the checkpoint's
        `model_state_dict` loaded, moved to CUDA and ready for inference.
    """
    # os.path.join keeps the checkpoint path portable — the original
    # 'ckpt\\BoSAM.pth' literal only resolved on Windows.
    checkpoint_path = os.path.join('ckpt', 'BoSAM.pth')

    # weights_only=False because the checkpoint is a dict with metadata,
    # not a bare state_dict.
    checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
    state_dict = checkpoint['model_state_dict']

    # Build the architecture without weights, then load the trained state.
    sam_model = sam_model_registry3D['vit_b_ori'](checkpoint=None).to('cuda')
    sam_model.load_state_dict(state_dict)

    return sam_model
| |
|
| |
|
def center_crop_or_pad(image_array, target_shape=(128, 128, 128)):
    """Center-crop or zero-pad a 3D array to `target_shape`, axis by axis.

    Axes larger than the target are cropped symmetrically; axes smaller
    than the target are zero-padded symmetrically. Mixed cases are handled.
    """
    shape = image_array.shape

    # Source window per axis: centred crop offset when the axis is larger.
    src_lo = [(dim - tgt) // 2 if dim > tgt else 0
              for dim, tgt in zip(shape, target_shape)]
    src_hi = [lo + tgt if dim > tgt else dim
              for lo, tgt, dim in zip(src_lo, target_shape, shape)]

    # Destination window per axis: centred pad offset when the axis is smaller.
    dst_lo = [0 if dim > tgt else (tgt - dim) // 2
              for dim, tgt in zip(shape, target_shape)]
    dst_hi = [tgt if dim > tgt else lo + dim
              for lo, dim, tgt in zip(dst_lo, shape, target_shape)]

    if all(dim >= tgt for dim, tgt in zip(shape, target_shape)):
        # Pure crop: slice the centred window directly, no copy buffer needed.
        return image_array[
            src_lo[0]:src_lo[0] + target_shape[0],
            src_lo[1]:src_lo[1] + target_shape[1],
            src_lo[2]:src_lo[2] + target_shape[2],
        ]

    # Pad (possibly mixed with crop): copy the overlap into a zero canvas.
    canvas = np.zeros(target_shape, dtype=image_array.dtype)
    src = tuple(slice(0, dim) if dim <= tgt else slice(lo, hi)
                for lo, hi, dim, tgt in zip(src_lo, src_hi, shape, target_shape))
    dst = tuple(slice(lo, hi) for lo, hi in zip(dst_lo, dst_hi))
    canvas[dst] = image_array[src]
    return canvas
| |
|
| |
|
def preprocess_image(image_path):
    """Read a NIfTI volume and return it as a (1, 1, 128, 128, 128) CUDA tensor.

    The volume is centre-cropped/zero-padded to 128^3 before the batch and
    channel dimensions are added.
    """
    volume = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
    volume = center_crop_or_pad(volume, (128, 128, 128))

    # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
    tensor = torch.tensor(volume).float().unsqueeze(0).unsqueeze(0)
    return tensor.to('cuda')
| |
|
| |
|
def load_gt3d(image3d_path):
    """Load a ground-truth label volume as a (1, 1, 128, 128, 128) CUDA tensor.

    Args:
        image3d_path: optional path to a label NIfTI file. When falsy
            (existing callers pass None), falls back to the bundled demo
            label — preserving the original behaviour, which always read
            that fixed file and ignored this parameter.

    Raises:
        FileNotFoundError: if the resolved path does not exist.
    """
    # os.path.join keeps the fallback portable (the original raw string
    # used Windows backslashes).
    default_path = os.path.join('examples', 'labels', '1.3.6.1.4.1.9328.50.4.0357.nii.gz')
    gt3d_path = image3d_path if image3d_path else default_path
    if not os.path.exists(gt3d_path):
        raise FileNotFoundError(f"The file {gt3d_path} does not exist.")

    gt_array = sitk.GetArrayFromImage(sitk.ReadImage(gt3d_path))
    gt_array = center_crop_or_pad(gt_array, (128, 128, 128))

    # Add batch and channel axes to match the image tensor layout.
    gt_tensor = torch.tensor(gt_array).float().unsqueeze(0).unsqueeze(0)
    return gt_tensor.to('cuda')
| |
|
| |
|
def overlay_mask_on_image(image_slice, mask_slice, alpha=0.6):
    """Composite a binary mask (semi-transparent blue) over a grayscale slice.

    Args:
        image_slice: 2D array of raw intensities.
        mask_slice: 2D array; values > 0 are treated as foreground.
        alpha: mask opacity in [0, 1].

    Returns:
        A PIL RGBA image of the contrast-enhanced slice with the mask on top.
    """
    # Percentile windowing for contrast. Guard against a constant slice —
    # p98 == p2 would divide by zero (normalize_image() already guards this;
    # the original here did not).
    p2, p98 = np.percentile(image_slice, (2, 98))
    if p98 - p2 != 0:
        image_contrast = np.clip((image_slice - p2) / (p98 - p2), 0, 1)
    else:
        image_contrast = np.zeros_like(image_slice, dtype=float)
    image_contrast = (image_contrast * 255).astype(np.uint8)

    image_rgb = Image.fromarray(image_contrast).convert("RGB")

    # Additional autocontrast pass, then RGBA for alpha compositing.
    enhancer = ImageOps.autocontrast(image_rgb, cutoff=0.5)
    image_rgba = enhancer.convert("RGBA")

    # Transparent layer carrying the coloured mask.
    mask_image = Image.new('RGBA', image_rgba.size, (0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)

    mask = (mask_slice > 0).astype(np.uint8) * 255
    mask_pil = Image.fromarray(mask, mode='L')

    # Blue overlay (41, 128, 255) with the requested opacity.
    mask_draw.bitmap((0, 0), mask_pil, fill=(41, 128, 255, int(255 * alpha)))

    combined_image = Image.alpha_composite(image_rgba, mask_image)
    return combined_image
| |
|
| |
|
def predict(image3D, sam_model, points=None, prev_masks=None, num_clicks=5):
    """Predict a segmentation mask with the SAM model via simulated clicks.

    Args:
        image3D: (1, 1, 128, 128, 128) CUDA tensor of the input volume
            (shape inferred from preprocess_image / the 128^3 interpolation
            below — TODO confirm).
        sam_model: SAM-3D model exposing image_encoder, prompt_encoder
            and mask_decoder.
        points: unused — click points are generated internally each round.
        prev_masks: optional previous mask tensor; zeros are used when None.
        num_clicks: number of click-refinement iterations.

    Returns:
        Tuple of (binary uint8 mask, sigmoid probability map), both
        squeezed to (128, 128, 128) numpy arrays.
    """
    sam_model.eval()

    # Z-score normalisation over the whole volume.
    # NOTE(review): std() == 0 for a constant volume would produce NaNs — confirm inputs.
    image3D = (image3D - image3D.mean()) / image3D.std()

    # NOTE(review): ground truth is loaded at inference time so clicks can be
    # sampled from it (demo setup); load_gt3d ignores its argument and reads
    # a fixed label file.
    gt3D = load_gt3d(None)

    if prev_masks is None:
        prev_masks = torch.zeros_like(image3D).to('cuda')

    # Low-resolution mask prompt expected by the prompt encoder
    # (default nearest-neighbour interpolation).
    low_res_masks = F.interpolate(prev_masks.float(), size=(32, 32, 32))

    # Encode the image once; the embedding is reused for every click round.
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(image3D)

    for num_click in range(num_clicks):
        with torch.no_grad():
            # Sample the next corrective click from the disagreement between
            # the current prediction and the ground truth.
            batch_points, batch_labels = get_next_click3D_torch_2(prev_masks.to('cuda'), gt3D.to('cuda'))

            points_co = torch.cat(batch_points, dim=0).to('cuda')
            points_la = torch.cat(batch_labels, dim=0).to('cuda')

            # Encode click prompts plus the previous low-res mask.
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=low_res_masks.to('cuda'),
            )

            low_res_masks, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding.to('cuda'),
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )

            # Upsample back to full resolution; feeds the next click round.
            prev_masks = F.interpolate(low_res_masks, size=[128, 128, 128], mode='trilinear', align_corners=False)

    # Threshold the final probabilities at 0.5 for the binary mask.
    medsam_seg_prob = torch.sigmoid(prev_masks)
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

    return medsam_seg, medsam_seg_prob
| |
|
| |
|
def normalize_image(image):
    """Window a slice to its 2nd-98th percentile range and scale to uint8."""
    lo, hi = np.percentile(image, (2, 98))
    span = hi - lo
    if span == 0:
        # Constant slice: avoid division by zero, render as black.
        scaled = np.zeros_like(image)
    else:
        scaled = np.clip((image - lo) / span, 0, 1)
    return (scaled * 255).astype(np.uint8)
| |
|
| |
|
def predicts(img_path, sam_model):
    """Convenience wrapper: preprocess a NIfTI file, then run SAM prediction."""
    volume = preprocess_image(img_path)
    return predict(volume, sam_model)
| |
|
| |
|
def save_nifti(prediction, original_image_path):
    """Save a 128^3 prediction as a NIfTI file, copying geometry from the source.

    Args:
        prediction: 3D array of shape (128, 128, 128); cast to uint8.
        original_image_path: path of the source image; its direction and
            origin are copied, and its spacing rescaled so the fixed
            128-voxel grid spans the same physical extent.

    Returns:
        Path of the temporary .nii.gz file (caller is responsible for cleanup).
    """
    original_image = sitk.ReadImage(original_image_path)

    output_image = sitk.GetImageFromArray(prediction.astype(np.uint8))
    output_image.SetDirection(original_image.GetDirection())
    output_image.SetOrigin(original_image.GetOrigin())

    # Rescale spacing: original_size voxels at original_spacing cover the
    # same extent as 128 voxels at the new spacing.
    original_size = original_image.GetSize()
    original_spacing = original_image.GetSpacing()
    new_spacing = [
        spacing * (size / 128)
        for spacing, size in zip(original_spacing, original_size)
    ]
    output_image.SetSpacing(new_spacing)

    # Create the temp file, then CLOSE the handle before SimpleITK writes to
    # the same path — the original left it open, which leaks the handle and
    # fails on Windows where the open file is locked by this process.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".nii.gz")
    temp_filename = temp_file.name
    temp_file.close()

    sitk.WriteImage(output_image, temp_filename)

    return temp_filename
| |
|
| |
|
def gr_interface(img_path, sam_model=None):
    """Gradio streaming handler: segment a NIfTI volume, yielding progress.

    Yields 5-tuples matching the wired outputs:
    (original slice gallery, status textbox update, overlay gallery,
     mask gallery, downloadable NIfTI file path).
    Intermediate yields update only the status text; the final yield
    delivers all galleries and the result file.
    """
    if sam_model is None:
        sam_model = build_model()

    # Status strings below are user-facing Chinese UI text
    # ("loading data..." / "segmenting..." / "generating visualisation..." / "done!").
    yield None, gr.update(value="正在加载数据..."), None, None, None

    processed_img = preprocess_image(img_path)

    yield None, gr.update(value="正在分割..."), None, None, None

    # NOTE(review): predicts() preprocesses img_path again internally, so the
    # volume is loaded twice; processed_img here is only used for display.
    prediction, prediction_prob = predicts(img_path, sam_model)

    yield None, gr.update(value="正在生成可视化..."), None, None, None

    processed_slices = []
    combined_slices = []
    predicted_slices = []

    nifti_file_path = save_nifti(prediction, img_path)

    # Display only the central 32 axial slices of the 128-slice volume.
    start_idx = (128 - 32) // 2
    end_idx = start_idx + 32

    for i in range(start_idx, end_idx):
        # Contrast-normalised raw slice for the "original" gallery.
        processed_slice = processed_img[0, 0, i].cpu().numpy()
        processed_slices.append(normalize_image(processed_slice))

        # Binary mask slice and its uint8 rendering for the mask gallery.
        mask_slice = prediction[i]
        normalized_mask = normalize_image(mask_slice)

        # Mask composited over the just-normalised slice.
        combined_image = overlay_mask_on_image(processed_slices[-1], mask_slice)
        combined_slices.append(combined_image)

        predicted_slices.append(normalized_mask)

    yield processed_slices, gr.update(value="分割完成!"), combined_slices, predicted_slices, nifti_file_path
| |
|
| |
|
| | |
# Demo volumes bundled with the app. Built with os.path.join for portability —
# the original hard-coded backslashes only resolved on Windows.
_EXAMPLE_NAMES = [
    "1.3.6.1.4.1.9328.50.4.0327.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0357.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0477.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0491.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0708.nii.gz",
    "1.3.6.1.4.1.9328.50.4.0719.nii.gz",
]
# gr.Examples expects one row (list) per example.
EXAMPLES = [[os.path.join("examples", name)] for name in _EXAMPLE_NAMES]
# Pre-selected example shown when the app loads.
DEFAULT_EXAMPLE = EXAMPLES[0][0]
| |
|
| | |
# Custom stylesheet for the Gradio Blocks UI: page background, centred
# container, fade-in animated titles, and hover effects for the action
# button and gallery items.
css = """
body {
    background-color: #f8fafc;
}

.container {
    max-width: 1200px;
    margin: 0 auto;
}

.main-title {
    text-align: center;
    color: #2563eb;
    font-size: 2.5rem;
    margin-bottom: 1rem;
    font-weight: bold;
    animation: fadeIn 1.5s ease-in-out;
}

.sub-title {
    text-align: center;
    color: #1e293b;
    margin-bottom: 2rem;
    animation: fadeIn 2s ease-in-out;
}

.custom-button {
    background-color: #2563eb !important;
    color: white !important;
    transition: transform 0.3s, box-shadow 0.3s;
}

.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
}

.gallery-item {
    border-radius: 8px;
    overflow: hidden;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    transition: transform 0.3s;
}

.gallery-item:hover {
    transform: scale(1.02);
    box-shadow: 0 6px 12px rgba(0, 0, 0, 0.15);
}

@keyframes fadeIn {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
"""
| |
|
# Build the model once at import time so every request reuses the same
# loaded weights (loading the checkpoint per-request would be very slow).
sam_model = build_model()
| |
|
| | |
# Gradio UI definition. Left column: file input, status, run button and
# bundled examples; right column: original-slice gallery, overlay/mask
# galleries, and a download widget for the 3D result. All user-facing text
# is Chinese.
with gr.Blocks(title="3D医学影像智能分割系统", css=css) as demo:
    gr.HTML("<h1 class='main-title'>3D医学影像智能分割系统</h1>")
    gr.HTML("<p class='sub-title'>基于BoSAM的前沿人工智能自动分割技术,为医学影像分析提供高精度解决方案</p>")

    with gr.Row():
        with gr.Column(scale=1):
            # Input panel: NIfTI upload pre-filled with the default example.
            gr.Markdown("### 上传/选择影像")
            input_file = gr.File(label="上传NIfTI文件", value=DEFAULT_EXAMPLE)

            status = gr.Textbox(label="处理状态", value="准备就绪")
            process_btn = gr.Button("开始智能分割", elem_classes=["custom-button"])

            # Clickable example dataset rows.
            gr.Markdown("### 示例数据")
            examples = gr.Examples(
                examples=EXAMPLES,
                inputs=[input_file]
            )

            gr.HTML("""
            <div style="margin-top: 2rem; padding: 1rem; background-color: rgba(16, 185, 129, 0.1); border-radius: 8px;">
                <h3 style="color: #10b981; margin-bottom: 0.5rem;">技术亮点</h3>
                <ul style="margin-left: 1.5rem;">
                    <li>基于最新的Segment Anything Model (SAM) 技术</li>
                    <li>专为3D医学影像优化的深度学习模型</li>
                    <li>智能识别解剖结构,无需手动绘制边界</li>
                    <li>高精度分割结果,提升诊断效率</li>
                </ul>
            </div>
            """)

        with gr.Column(scale=2):
            # Result panel: galleries fed by gr_interface's final yield.
            with gr.Row():
                gr.Markdown("## 原始医学影像")
                output_original = gr.Gallery(label="", show_label=False, columns=4, rows=8, height="600px", elem_classes=["gallery-item"])

            with gr.Row():
                with gr.Column():
                    gr.Markdown("## 分割叠加结果")
                    output_combined = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])

                with gr.Column():
                    gr.Markdown("## 分割掩码")
                    output_mask = gr.Gallery(label="", show_label=False, columns=4, rows=4, height="400px", elem_classes=["gallery-item"])

            gr.Markdown("## 分割结果下载")
            output_file = gr.File(label="下载完整3D分割结果 (NIFTI格式)")

    gr.HTML("""
    <div style="text-align: center; margin-top: 2rem; padding: 1rem; border-top: 1px solid rgba(0, 0, 0, 0.1);">
        <p>© 2025 3D医学影像智能分割系统 | 人工智能辅助医学影像分析平台</p>
        <p>基于最新的BoaSAM模型,为医疗影像分析提供高精度自动分割解决方案</p>
    </div>
    """)

    # All three triggers (button click, example selection, initial page load)
    # run the same streaming handler against the same outputs.
    process_btn.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )

    examples.dataset.click(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )

    demo.load(
        fn=gr_interface,
        inputs=[input_file],
        outputs=[output_original, status, output_combined, output_mask, output_file]
    )
| |
|
if __name__ == "__main__":
    # Debug logging plus a public share link for the demo.
    demo.launch(debug=True, share=True)