Spaces:
Running
Running
File size: 6,980 Bytes
3047e70 ba0c28c 3047e70 ba0c28c 3047e70 32cdeb9 3047e70 0fb9984 ba0c28c 0fb9984 3047e70 ba0c28c 0fb9984 3047e70 ba0c28c 3047e70 0fb9984 3047e70 0fb9984 3047e70 ba0c28c 3047e70 4cffc64 0fb9984 4cffc64 0fb9984 68482bc 0fb9984 68482bc ba0c28c 0fb9984 ba0c28c 0fb9984 3047e70 68482bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import gradio as gr
from datetime import datetime
from huggingface_hub import hf_hub_download
import torch
import json
from PIL import Image
from PIL import ImageDraw, ImageFont
import numpy as np
from model import MIPHEIViT
# Load the model and its config once at import time; every request reuses them.
repo_id = "Estabousi/MIPHEI-vit"
model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
model.eval()  # evaluation mode (affects layers such as dropout)
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")

# Per-channel input normalization statistics (the ImageNet mean/std values),
# shaped [3, 1, 1] so they broadcast over a [3, H, W] image tensor.
mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape(-1, 1, 1)

# JSON is UTF-8; do not rely on the platform's default text encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
# Names of the predicted mIF channels; predict() pairs output channel i with
# channel_names[i].
channel_names = config["targ_channel_names"]
# Display color (r, g, b) for each predicted channel. Every component must lie
# in 0-255: apply_color_map multiplies these values by a [0, 1] intensity and
# casts to uint8. Note that CD4 and Pan-CK currently share pure red.
channel_colors = {
    "Hoechst": (0, 0, 255),  # Blue (DAPI, nuclear stain)
    "CD31": (0, 255, 255),  # Cyan (endothelial)
    "CD45": (255, 255, 0),  # Yellow (leukocyte common antigen)
    "CD68": (255, 165, 0),  # Orange (macrophages)
    "CD4": (255, 0, 0),  # Red (helper T cells)
    "FOXP3": (138, 43, 226),  # Purple/Blue-Violet (regulatory T cells)
    "CD8a": (255, 100, 100),  # Light red (cytotoxic T cells); red was 303, out of 8-bit range
    "CD45RO": (255, 105, 180),  # Hot Pink (memory T cells)
    "CD20": (0, 191, 255),  # Deep Sky Blue (B cells)
    "PD-L1": (255, 0, 255),  # Magenta
    "CD3e": (95, 95, 94),  # Dark gray (T cells)
    "CD163": (184, 134, 11),  # Dark Goldenrod (M2 macrophages)
    "E-cadherin": (242, 12, 43),  # Crimson red (epithelial marker)
    "Ki67": (255, 20, 147),  # Deep Pink (proliferation marker)
    "Pan-CK": (255, 0, 0),  # Red (epithelial/carcinoma)
    "SMA": (0, 255, 0),  # Green (smooth muscle, myofibroblasts)
}
# Display contrast correction. predict() divides each predicted channel (already
# scaled to [0, 1]) by value / 255, so a channel reaches full brightness at
# value / 255. Per channel, value is:
#   - Hoechst: 255 (no stretch);
#   - CD8a, CD31, CD4, CD68, FOXP3: 100 (strongest stretch);
#   - every other channel: default_contrast (150).
default_contrast = 150.0
correction_map = {
    "Hoechst": 255.0,
    "CD8a": 100.0,
    "CD31": 100.0,
    "CD4": 100.0,
    "CD68": 100.0,
    "FOXP3": 100.0,
}
# Shape [C, 1, 1] so it broadcasts over a [C, H, W] prediction.
max_contrast_correction_value = torch.tensor(
    [correction_map.get(name, default_contrast) / 255 for name in channel_names],
    dtype=torch.float32,
).reshape(len(channel_names), 1, 1)
# Markers blended into the overlay image and listed in its legend.
overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
def preprocess(image):
    """Convert a PIL image into a normalized single-image batch for the model.

    The image is forced to RGB, resized to 256x256, scaled from [0, 255] to
    [0, 1], then normalized channel-wise with the module-level ``mean``/``std``.

    Returns:
        float32 tensor of shape [1, 3, 256, 256].
    """
    rgb = image.convert("RGB").resize((256, 256))
    pixels = np.array(rgb)  # [H, W, 3]
    chw = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255
    normalized = (chw - mean) / std
    return normalized[None]  # prepend batch axis -> [1, 3, H, W]
def draw_legend_on_image(image, channel_names, channel_colors, indices, box_size=18, spacing=5, top_margin=5):
    """Draw a semi-transparent legend on the bottom-right of the image.

    Each entry is a filled color swatch followed by the channel name.

    Args:
        image: PIL image to annotate; a converted copy is drawn on.
        channel_names: sequence of channel names.
        channel_colors: mapping from channel name to an ``(r, g, b)`` tuple.
        indices: positions in ``channel_names`` to list, top to bottom.
        box_size: side length of each color swatch, in pixels.
        spacing: vertical gap between consecutive entries, in pixels.
        top_margin: padding above the first entry, in pixels.

    Returns:
        New ``PIL.Image.Image`` in RGB mode with the legend composited on top.
        When ``indices`` is empty, the image is returned as RGB with no legend.
    """
    # len() rather than truthiness so numpy index arrays are accepted too.
    if len(indices) == 0:
        return image.convert("RGB")

    overlay = image.convert("RGBA")  # alpha channel needed for alpha_composite
    legend_layer = Image.new("RGBA", overlay.size, (255, 255, 255, 0))
    draw = ImageDraw.Draw(legend_layer)
    font = ImageFont.load_default()

    names = [channel_names[idx] for idx in indices]
    text_gap = 5  # horizontal gap between a swatch and its label
    # Size the legend to its widest label so longer names stay on the white
    # background instead of spilling onto the image; 60 px is the minimum.
    widest_label = max(draw.textlength(name, font=font) for name in names)
    legend_width = max(60, box_size + text_gap + int(widest_label) + 1)
    legend_height = top_margin + box_size * len(names) + spacing * (len(names) - 1)
    x_start = overlay.width - legend_width - 10
    y_start = overlay.height - legend_height - 10

    # Semi-transparent white background with a 5 px border around the content
    draw.rectangle(
        [x_start - 5, y_start - 5, x_start + legend_width + 5, y_start + legend_height + 5],
        fill=(255, 255, 255, 180),
    )
    for i, name in enumerate(names):
        # Offset by top_margin so the padding counted in legend_height sits
        # above the first entry, as the parameter name says.
        y = y_start + top_margin + i * (box_size + spacing)
        color = tuple(channel_colors[name])
        draw.rectangle([x_start, y, x_start + box_size, y + box_size], fill=color + (255,))
        draw.text((x_start + box_size + text_gap, y), name, fill=(0, 0, 0, 255), font=font)

    # Composite legend onto overlay
    combined = Image.alpha_composite(overlay, legend_layer)
    return combined.convert("RGB")  # back to RGB for display
def merge_colored_images(color_imgs, top4_idx):
    """Additively blend a subset of same-sized RGB images into one RGB image.

    Args:
        color_imgs: sequence of RGB images (PIL images or HxWx3 arrays).
        top4_idx: indices into ``color_imgs`` of the images to blend. Despite
            the name, any number of indices is accepted.

    Returns:
        ``PIL.Image.Image`` in RGB mode whose pixels are the per-channel sum of
        the selected images, saturated at 255. An empty ``top4_idx`` yields a
        black image the size of ``color_imgs[0]``.
    """
    # Accumulate in float32 so sums above 255 saturate instead of wrapping.
    accum = np.zeros_like(np.asarray(color_imgs[0]), dtype=np.float32)
    for idx in top4_idx:
        accum += np.asarray(color_imgs[idx]).astype(np.float32)
    merged = np.clip(accum, 0, 255).astype(np.uint8)
    # Mode is inferred (uint8 HxWx3 -> "RGB"); the explicit ``mode`` argument
    # of fromarray is deprecated since Pillow 11.3.
    return Image.fromarray(merged)
def apply_color_map(gray_img, rgb_color):
    """Map a grayscale image to RGB using a fixed pseudocolor.

    Output channel k is ``gray / 255 * rgb_color[k]``: black stays black and
    full intensity maps to ``rgb_color``.

    Args:
        gray_img: single-channel image (PIL ``L`` image or 2-D array) with
            values expected in [0, 255].
        rgb_color: ``(r, g, b)`` display color; only the first three entries
            are used.

    Returns:
        ``PIL.Image.Image`` in RGB mode.
    """
    gray = np.asarray(gray_img).astype(np.float32) / 255.0
    color = np.asarray(rgb_color[:3], dtype=np.float32)
    # (H, W, 1) * (3,) broadcasts to (H, W, 3). Clip before the uint8 cast so
    # components above 255 saturate instead of overflowing.
    rgb = np.clip(gray[..., np.newaxis] * color, 0, 255).astype(np.uint8)
    # Mode inferred as "RGB"; the ``mode`` argument is deprecated since Pillow 11.3.
    return Image.fromarray(rgb)
def predict(image):
    """Predict mIF channels for one H&E image and render them for display.

    Args:
        image: input H&E image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user runs the prediction without an image.

    Returns:
        List of PIL RGB images: first the additive overlay of
        ``overlay_markers`` with a legend, then one pseudocolored image per
        predicted channel, in output-channel order.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an H&E image or select an example first.")
    print(f"[{datetime.now().isoformat()}] Inference run")
    input_tensor = preprocess(image)
    with torch.inference_mode():
        output = model(input_tensor)[0]  # [16, H, W]
    # Rescale predictions from [-0.9, 0.9] to [0, 1].
    output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
    # Per-channel contrast stretch; the clamp guards against division by zero.
    divisor = max_contrast_correction_value.to(output.device).clamp(min=1e-6)
    output_vis = (output / divisor).clamp(0, 1) * 255
    output_vis = np.uint8(output_vis.cpu().numpy())

    # Pseudocolor each predicted channel with its display color.
    channel_imgs = []
    for i, plane in enumerate(output_vis):
        ch_name = channel_names[i]
        ch_gray = Image.fromarray(plane)  # 2-D uint8 -> mode "L" (inferred)
        channel_imgs.append(apply_color_map(ch_gray, channel_colors[ch_name]))

    # Blend the fixed overlay markers and label them with a legend.
    fixed_idx = [channel_names.index(name) for name in overlay_markers]
    overlay = merge_colored_images(channel_imgs, fixed_idx)
    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, fixed_idx)
    return [overlay_with_legend] + channel_imgs
# Markdown header shown at the top of the page
with open("HEADER.md", "r", encoding="utf-8") as f:
    HEADER_MD = f.read()

# Build interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown(HEADER_MD)
    with gr.Row():
        # LEFT: input + examples + button
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input H&E")
            run_btn = gr.Button("Run Prediction")
            gr.Examples(
                examples=[
                    ["examples/crc100k_val.jpg"],
                    ["examples/orion_test_1.jpg"],
                    ["examples/orion_test_2.jpg"],
                    ["examples/orion_test_3.jpg"],
                    ["examples/orion_test_4.jpg"],
                    ["examples/orion_test_5.jpg"],
                    ["examples/tcga.jpg"],
                    ["examples/hemit.jpg"],
                ],
                inputs=[input_image],
                label="Example H&E tile (TCGA, ORION Test, CRC100K, HEMIT)",
            )
        # RIGHT: outputs
        with gr.Column(scale=2):
            overlay_image = gr.Image(type="pil", label="mIF Overlay")
            # One output per entry in channel_names (not a hard-coded count),
            # so the component list always matches predict()'s per-channel images.
            channel_images = [
                gr.Image(type="pil", label=f"mIF Channel {name}")
                for name in channel_names
            ]
    output_images = [overlay_image] + channel_images
    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
|