File size: 5,234 Bytes
b763406
2171052
b763406
 
a7ebb0c
2171052
b763406
12eeac8
914b995
d20b41e
12eeac8
 
 
914b995
12eeac8
 
 
b763406
 
 
 
 
 
 
 
90fd4b5
 
3b4aa88
2171052
b763406
 
 
08587f3
 
 
 
b763406
 
 
7dd825e
b763406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dd825e
b763406
 
 
 
 
 
 
 
914b995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d20b41e
 
 
914b995
d20b41e
 
 
914b995
d20b41e
 
914b995
 
d20b41e
914b995
 
 
 
6ee6793
914b995
6ee6793
 
233abd4
914b995
 
 
6ee6793
 
233abd4
914b995
 
4707b83
daab25b
 
 
 
 
 
 
 
 
 
 
2171052
914b995
14f3609
914b995
d473901
914b995
 
d20b41e
697ee78
 
914b995
8f1428e
 
 
d20b41e
a1092e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import gradio as gr
from huggingface_hub import hf_hub_download

# Compatibility shim for `huggingface_hub.cached_download`.
# Presumably installed because the huggingface_hub in use no longer provides
# that function while code imported below (diffusers / project modules) still
# imports it. It must run before those imports.
def cached_download(*args, **kwargs):
    """Stand-in for the deprecated ``huggingface_hub.cached_download``.

    Prints a warning and forwards every argument unchanged to
    ``hf_hub_download``, returning its result. The old ``cached_download``
    took a URL as its first argument, whereas ``hf_hub_download`` expects
    ``repo_id``/``filename``. Only callers already using the
    ``hf_hub_download`` calling convention will therefore work through it.
    """
    print("Warning: cached_download is deprecated, using hf_hub_download instead.")
    return hf_hub_download(*args, **kwargs)

import sys
import huggingface_hub

# Expose the shim as a real attribute of the package, so that both
# `huggingface_hub.cached_download` and `from huggingface_hub import
# cached_download` resolve through normal attribute lookup. An existing
# implementation is never overridden. The sys.modules registration is kept
# for anything that looks the name up there.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = cached_download
sys.modules["huggingface_hub.cached_download"] = cached_download

from diffusers import AutoencoderKL, DDPMScheduler
from StableDiffusion.Our_UNet import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from medical_pipeline import MedicalPipeline
from diffusers import DDIMScheduler
from StableDiffusion.Our_Pipe import StableDiffusionPipeline

model_repo_id = "runwayml/stable-diffusion-v1-5"
medsegfactory_id = "JohnWeck/StableDiffusion"
filename = 'checkpoint-300.pth'

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float32

# Base Stable Diffusion components. They are loaded once here and handed to the
# pipeline below. Any component not passed to `from_pretrained` would be loaded
# from the repo a second time.
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")

# The project UNet is instantiated from the base repo's UNet config. It is then
# filled with the state dict from the MedSegFactory checkpoint on the Hub.
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")

medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
# NOTE(security): torch.load unpickles the downloaded file. Only point this at a
# trusted repository. Since a plain state dict is all that is needed, consider
# weights_only=True on torch versions that support it.
unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))

vae.requires_grad_(False)
text_encoder.requires_grad_(False)

unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

# DDIM sampler with an explicit noise-schedule configuration.
sd_noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# Load the SD pipeline, reusing the already-loaded models instead of having
# from_pretrained instantiate fresh copies of the VAE and text encoder.
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=weight_dtype,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    scheduler=sd_noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)

# NOTE(review): presumably MedicalPipeline handles any further device placement
# of `pipe` — confirm in medical_pipeline.
pipeline = MedicalPipeline(pipe, device)

# Mapping from each dataset key (the "Keys" dropdown) to:
#   "organs": the organ choices offered for that key, and
#   "kinds":  for each organ, the kind choices offered in the "Kind" checkboxes.
def _single_organ_entry(organ, kinds):
    """Return the mapping entry for a dataset that exposes exactly one organ."""
    return {"organs": [organ], "kinds": {organ: list(kinds)}}


keys_to_organ_kind = {
    "CVC-ClinicDB": _single_organ_entry("polyp colonoscopy", ["polyp"]),
    "BUSI": _single_organ_entry("breast ultrasound", ["normal", "breast tumor"]),
    "LiTS2017": _single_organ_entry("abdomen CT scans", ["liver", "liver tumor"]),
    "KiTS2019": _single_organ_entry("abdomen CT scans", ["kidney", "kidney tumor"]),
    "ACDC": _single_organ_entry(
        "cardiovascular ventricle mri",
        ["right ventricle", "myocardium", "left ventricle"],
    ),
    "AMOS2022": _single_organ_entry(
        "abdomen CT scans",
        [
            "liver", "right kidney", "spleen", "pancreas", "aorta",
            "inferior vena cava", "right adrenal gland", "left adrenal gland",
            "gall bladder", "esophagus", "stomach", "duodenum", "left kidney",
            "bladder", "prostate",
        ],
    ),
}


def update_organ_and_kind(selected_key):
    """Refresh the organ dropdown and kind checkboxes after the dataset key changes.

    The organ dropdown is repopulated with the key's organs and set to the first
    one ("" when the key lists none). The kind checkboxes are repopulated with
    that organ's kinds, all of them pre-selected.

    Returns:
        A pair of ``gr.update`` objects: (organ dropdown, kind checkbox group).
    """
    entry = keys_to_organ_kind[selected_key]
    organ_choices = entry["organs"]
    default_organ = organ_choices[0] if organ_choices else ""
    kind_choices = entry["kinds"].get(default_organ, [])
    organ_update = gr.update(choices=organ_choices, value=default_organ)
    kind_update = gr.update(choices=kind_choices, value=kind_choices)
    return organ_update, kind_update


def update_kind(selected_key, selected_organ):
    """Repopulate the kind checkboxes for *selected_organ* under *selected_key*.

    Every kind listed for the organ is offered and pre-selected. An organ that is
    not listed for the key yields an empty group.
    """
    organ_to_kinds = keys_to_organ_kind[selected_key]["kinds"]
    available = organ_to_kinds.get(selected_organ, [])
    return gr.update(choices=available, value=available)


def generate_image(organ, kinds, keys):
    """Generate an image/label pair and save a side-by-side preview PNG.

    Args:
        organ: Organ string chosen in the "Organ" dropdown.
        kinds: Iterable of kind strings selected in the "Kind" checkbox group.
            They are joined with "," into the single ``kind`` string passed to
            the pipeline.
        keys: Dataset key chosen in the "Keys" dropdown.

    Returns:
        The path "pred.png", containing the generated image (left) and label
        (right) with axes hidden.
    """
    kind = ",".join(kinds)
    print(f"Debug Info -> Organ: {organ}, Kind: {kind}, Keys: {keys}")
    image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)

    # Draw on a dedicated figure and always close it. pyplot's implicit
    # "current figure" would otherwise be reused across calls: each request
    # would stack another image onto the same axes, and the figure would never
    # be released.
    fig, (image_ax, label_ax) = plt.subplots(1, 2)
    try:
        image_ax.imshow(image)
        image_ax.axis('off')
        label_ax.imshow(label)
        label_ax.axis('off')
        fig.savefig('pred.png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return "pred.png"

with gr.Blocks() as demo:
    # Initial UI state: the default key's first organ, with all of that organ's
    # kinds pre-selected. This is the same state update_organ_and_kind produces
    # on a key change, so "Generate Image" does not send an empty kind list
    # when the controls have not been touched.
    _default_key = "CVC-ClinicDB"
    _default_organs = keys_to_organ_kind[_default_key]["organs"]
    _default_organ = _default_organs[0]
    _default_kinds = keys_to_organ_kind[_default_key]["kinds"][_default_organ]

    gr.Markdown("### 📌 Note: This app is running under free CPU. Make sure to provide reasonable kind combinations.")
    keys_dropdown = gr.Dropdown(list(keys_to_organ_kind.keys()), label="Keys", value=_default_key)
    organ_dropdown = gr.Dropdown(_default_organs, label="Organ", value=_default_organ)
    kind_checkbox = gr.CheckboxGroup(_default_kinds, label="Kind", value=list(_default_kinds))

    # Changing the key repopulates organ and kind; changing the organ repopulates kind.
    keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
    organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Visualization")
    generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)

demo.launch()