Spaces:
Running
Running
File size: 5,234 Bytes
b763406 2171052 b763406 a7ebb0c 2171052 b763406 12eeac8 914b995 d20b41e 12eeac8 914b995 12eeac8 b763406 90fd4b5 3b4aa88 2171052 b763406 08587f3 b763406 7dd825e b763406 7dd825e b763406 914b995 d20b41e 914b995 d20b41e 914b995 d20b41e 914b995 d20b41e 914b995 6ee6793 914b995 6ee6793 233abd4 914b995 6ee6793 233abd4 914b995 4707b83 daab25b 2171052 914b995 14f3609 914b995 d473901 914b995 d20b41e 697ee78 914b995 8f1428e d20b41e a1092e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
# Compatibility shim: code that still does `from huggingface_hub import
# cached_download` (a name newer huggingface_hub releases no longer export)
# gets a thin alias that forwards to hf_hub_download instead.
import sys


def cached_download(*args, **kwargs):
    """Deprecated alias that forwards every argument to ``hf_hub_download``.

    NOTE(review): the legacy huggingface_hub ``cached_download`` took a URL as
    its first positional argument, while ``hf_hub_download`` expects
    ``repo_id``/``filename``; URL-style calls will not translate. Confirm no
    caller actually invokes this alias that way — it mainly needs to exist so
    the import succeeds.
    """
    print("Warning: cached_download is deprecated, using hf_hub_download instead.")
    return hf_hub_download(*args, **kwargs)


def _install_cached_download_shim(hub_module):
    """Attach ``cached_download`` to *hub_module* unless it already defines one.

    Setting the attribute makes both ``from huggingface_hub import
    cached_download`` and ``huggingface_hub.cached_download`` resolve. An
    existing (real) implementation is left untouched; ``None`` is a no-op.
    """
    if hub_module is not None and not hasattr(hub_module, "cached_download"):
        hub_module.cached_download = cached_download


# huggingface_hub is already in sys.modules because of the
# `from huggingface_hub import hf_hub_download` import at the top of the file.
_install_cached_download_shim(sys.modules.get("huggingface_hub"))
# Kept from the original approach: when the attribute lookup fails,
# CPython's `from pkg import name` falls back to sys.modules["pkg.name"].
sys.modules["huggingface_hub.cached_download"] = cached_download
from diffusers import AutoencoderKL, DDPMScheduler
from StableDiffusion.Our_UNet import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from medical_pipeline import MedicalPipeline
from diffusers import DDIMScheduler
from StableDiffusion.Our_Pipe import StableDiffusionPipeline
model_repo_id = "runwayml/stable-diffusion-v1-5"
medsegfactory_id = "JohnWeck/StableDiffusion"
filename = 'checkpoint-300.pth'
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float32

# Base Stable Diffusion v1.5 components.
text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")
# NOTE(review): recent diffusers' ConfigMixin.from_config expects a config
# dict; passing a repo id relies on a deprecated path. Consider
# `UNet2DConditionModel.load_config(model_repo_id, subfolder="unet")` —
# verify against the custom StableDiffusion.Our_UNet class.
unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")

# Fine-tuned UNet weights from the MedSegFactory checkpoint.
medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the repo is compromised. If the installed torch supports
# it, pass weights_only=True (the result goes straight into load_state_dict,
# so a plain tensor dict should suffice — confirm).
unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))

# Inference only: freeze everything. The UNet was built via from_config, so it
# starts in train mode; switch it to eval explicitly.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.eval()
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

sd_noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

# Load the SD pipeline, reusing the components prepared above. Without passing
# vae/text_encoder explicitly, from_pretrained would load a second copy of each
# from the repo (not moved to `device` here), and the ones above would go unused.
pipe = StableDiffusionPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=weight_dtype,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder,
    scheduler=sd_noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)
pipeline = MedicalPipeline(pipe, device)
def _single_organ_entry(organ, kinds):
    """Return a mapping entry for a dataset that exposes exactly one organ prompt.

    The entry has the shape ``{"organs": [organ], "kinds": {organ: kinds}}``.
    """
    return {"organs": [organ], "kinds": {organ: list(kinds)}}


# Dataset key -> the organ prompt(s) it offers and, per organ, the kind
# (class) names the UI lets the user choose from.
keys_to_organ_kind = {
    "CVC-ClinicDB": _single_organ_entry("polyp colonoscopy", ["polyp"]),
    "BUSI": _single_organ_entry("breast ultrasound", ["normal", "breast tumor"]),
    "LiTS2017": _single_organ_entry("abdomen CT scans", ["liver", "liver tumor"]),
    "KiTS2019": _single_organ_entry("abdomen CT scans", ["kidney", "kidney tumor"]),
    "ACDC": _single_organ_entry(
        "cardiovascular ventricle mri",
        ["right ventricle", "myocardium", "left ventricle"],
    ),
    "AMOS2022": _single_organ_entry(
        "abdomen CT scans",
        [
            "liver", "right kidney", "spleen", "pancreas", "aorta",
            "inferior vena cava", "right adrenal gland", "left adrenal gland",
            "gall bladder", "esophagus", "stomach", "duodenum", "left kidney",
            "bladder", "prostate",
        ],
    ),
}
def update_organ_and_kind(selected_key):
    """Refresh the organ dropdown and kind checkboxes for a newly selected key.

    The organ dropdown is set to the key's first organ. The kind checkbox
    group offers that organ's kinds, with all of them pre-selected.

    Args:
        selected_key: A key of ``keys_to_organ_kind``. A key that is not in the
            mapping (e.g. a cleared dropdown) yields empty choices instead of
            raising ``KeyError`` inside the event handler.

    Returns:
        A pair of ``gr.update`` objects for (organ dropdown, kind checkbox group).
    """
    entry = keys_to_organ_kind.get(selected_key, {})
    organs = entry.get("organs", [])
    first_organ = organs[0] if organs else ""  # default to the first organ
    # Falls back to an empty list when the organ has no registered kinds.
    kinds = entry.get("kinds", {}).get(first_organ, [])
    return gr.update(choices=organs, value=first_organ), gr.update(choices=kinds, value=kinds)
def update_kind(selected_key, selected_organ):
    """Refresh the kind checkboxes for the selected key/organ pair.

    All kinds registered for the organ are offered and pre-selected. An unknown
    key or organ yields an empty checkbox group instead of raising ``KeyError``.

    Returns:
        A ``gr.update`` for the kind checkbox group.
    """
    entry = keys_to_organ_kind.get(selected_key, {})
    kinds = entry.get("kinds", {}).get(selected_organ, [])
    return gr.update(choices=kinds, value=kinds)
def generate_image(organ, kinds, keys, output_path="pred.png"):
    """Generate an image and its label, and save them side by side.

    Args:
        organ: Organ prompt selected in the UI.
        kinds: Selected kind names, joined with commas into the ``kind`` prompt.
            ``None`` (nothing selected) is treated as an empty selection.
        keys: Dataset key selected in the UI.
        output_path: File the two-panel figure is written to. Defaults to the
            original fixed ``pred.png``.

    Returns:
        ``output_path``, which the Gradio image component then displays.
    """
    kind = ",".join(kinds or [])
    print(f"Debug Info -> Organ: {organ}, Kind: {kind}, Keys: {keys}")
    image, label = pipeline.generate(organ=organ, kind=kind, keys=keys)

    # Draw on a dedicated figure and always close it. The pyplot state machine
    # (plt.subplot on the current figure) would reuse one global figure across
    # requests, stacking new images on old axes and never freeing them.
    fig, (ax_image, ax_label) = plt.subplots(1, 2)
    try:
        ax_image.imshow(image)
        ax_image.axis('off')
        ax_label.imshow(label)
        ax_label.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return output_path
with gr.Blocks() as demo:
    gr.Markdown("### 📌 Note: This app is running under free CPU. Make sure to provide reasonable kind combinations.")
    # Derive every initial widget value from a single default key instead of
    # repeating hard-coded strings.
    default_key = "CVC-ClinicDB"
    default_organ = keys_to_organ_kind[default_key]["organs"][0]
    default_kinds = keys_to_organ_kind[default_key]["kinds"][default_organ]
    keys_dropdown = gr.Dropdown(list(keys_to_organ_kind.keys()), label="Keys", value=default_key)
    organ_dropdown = gr.Dropdown(keys_to_organ_kind[default_key]["organs"], label="Organ", value=default_organ)
    # Pre-select all kinds, as the change handlers do (value=kinds). Otherwise
    # clicking Generate before touching any widget would send an empty kind.
    kind_checkbox = gr.CheckboxGroup(default_kinds, label="Kind", value=list(default_kinds))
    # Update organ and kind based on key
    keys_dropdown.change(update_organ_and_kind, inputs=keys_dropdown, outputs=[organ_dropdown, kind_checkbox])
    organ_dropdown.change(update_kind, inputs=[keys_dropdown, organ_dropdown], outputs=kind_checkbox)
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Visualization")
    generate_button.click(generate_image, inputs=[organ_dropdown, kind_checkbox, keys_dropdown], outputs=output_image)
demo.launch()
|