# BiomedParse / main.py
# kernel-luso-comfort's picture
# Add Apache License 2.0 header to multiple source files
# 202eff6
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image
def overlay_masks(image, masks, colors):
    """Tint the masked regions of ``image`` and return the blended picture.

    Args:
        image: Source image; it is copied and converted to a ``uint8`` array,
            so the caller's object is not modified.
        masks: Iterable of arrays; positions where a mask is ``> 0`` are tinted.
        colors: Iterable of colors paired with ``masks`` via ``zip`` (extra
            items on either side are ignored).

    Returns:
        A PIL image built from the blended array.
    """
    canvas = np.array(image.copy(), dtype=np.uint8)
    for mask, color in zip(masks, colors):
        selected = mask > 0
        # Blend: 40% of the existing pixel plus 60% of the mask color.
        tint = np.array(color) * 0.6
        blended = canvas[selected] * 0.4 + tint
        canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def generate_colors(n):
    """Return ``n`` RGB colors taken from matplotlib's "tab10" colormap.

    Args:
        n: Number of colors to generate.

    Returns:
        A list of ``n`` ``(r, g, b)`` tuples whose channels are ints scaled
        from the colormap's 0-1 floats by 255. When ``n`` is larger than the
        colormap, the colors repeat cyclically.
    """
    cmap = plt.get_cmap("tab10")
    # Integer indices at or beyond cmap.N are out of range for the colormap
    # and all resolve to the same "over" color, so wrap the index to cycle
    # through the palette instead of repeating one color.
    palette_size = cmap.N
    return [
        tuple(int(255 * channel) for channel in cmap(i % palette_size)[:3])
        for i in range(n)
    ]
def init_model():
    """Download the BiomedParse checkpoint and prepare the model for inference.

    Returns:
        The pretrained model after ``.eval().cuda()``, with text embeddings
        computed for ``BIOMED_CLASSES`` plus ``"background"``.
    """
    # Fetch the checkpoint from the Hugging Face Hub; the HF_TOKEN environment
    # variable (None when unset) is forwarded as the access token.
    checkpoint_path = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Build the options from the inference config, then pass them through the
    # distributed initializer.
    config_paths = ["configs/biomedparse_inference.yaml"]
    opt = init_distributed(load_opt_from_config_files(config_paths))

    # Construct the network, load the weights, switch to eval mode, move to GPU.
    wrapper = BaseModel(opt, build_model(opt))
    model = wrapper.from_pretrained(checkpoint_path)
    model = model.eval()
    model = model.cuda()

    # Compute the class text embeddings once up front, without tracking gradients.
    with torch.no_grad():
        text_encoder = model.model.sem_seg_head.predictor.lang_encoder
        text_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )
    return model
def predict(image, prompts):
    """Run prompt-based segmentation on ``image`` and return a mask overlay.

    Args:
        image: Input image exposing ``.mode`` and ``.convert`` (the Gradio
            input in this file is configured with ``type="pil"``). May be
            ``None`` when no image was supplied.
        prompts: Comma-separated prompt string, e.g.
            ``"neoplastic cells, inflammatory cells"``.

    Returns:
        The image with one colored overlay per prompt, or ``None`` when there
        is no image or no non-empty prompt to run.
    """
    # Nothing to do without both an image and some prompt text.
    if image is None or not prompts:
        return None

    # Split the text into individual prompts, dropping blank fragments such as
    # those produced by a trailing comma or whitespace-only input.
    stripped = (piece.strip() for piece in prompts.split(","))
    prompts = [piece for piece in stripped if piece]
    if not prompts:
        return None

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get predictions from the module-level model set up by run().
    pred_mask = interactive_infer_image(model, image, prompts)

    # Threshold each prompt's prediction at 0.5 and draw it in its own color.
    colors = generate_colors(len(prompts))
    binary_masks = [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))]
    return overlay_masks(image, binary_masks, colors)
def run():
    """Load the model into the module-level ``model`` and serve the Gradio demo.

    The interface is launched on host ``0.0.0.0``, port ``7860``.
    """
    # predict() reads the model from module scope, so publish it there.
    global model
    model = init_model()

    # Input and output components of the interface.
    image_input = gr.Image(type="pil", label="Input Image")
    prompt_input = gr.Textbox(
        label="Prompts",
        placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
    )
    prediction_output = gr.Image(type="pil", label="Prediction")

    # Each example row is [image path, prompt text].
    example_rows = [
        ["examples/144DME_as_F.jpeg", "edema"],
        ["examples/C3_EndoCV2021_00462.jpg", "polyp"],
        ["examples/covid_1585.png", "left lung"],
        ["examples/covid_1585.png", "right lung"],
        ["examples/covid_1585.png", "COVID-19 infection"],
        ["examples/ISIC_0015551.jpg", "lesion"],
        ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "lung nodule"],
        ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "COVID-19 infection"],
        ["examples/Part_1_516_pathology_breast.png", "connective tissue cells"],
        ["examples/Part_1_516_pathology_breast.png", "neoplastic cells"],
        [
            "examples/Part_1_516_pathology_breast.png",
            "neoplastic cells, inflammatory cells",
        ],
        ["examples/T0011.jpg", "optic disc"],
        ["examples/T0011.jpg", "optic cup"],
        ["examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png", "glioma"],
    ]

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, prompt_input],
        outputs=prediction_output,
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=example_rows,
    )
    demo.launch(server_name="0.0.0.0", server_port=7860)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()