Spaces:
Sleeping
Sleeping
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import os | |
import gradio as gr | |
import torch | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from huggingface_hub import hf_hub_download | |
from modeling.BaseModel import BaseModel | |
from modeling import build_model | |
from utilities.distributed import init_distributed | |
from utilities.arguments import load_opt_from_config_files | |
from utilities.constants import BIOMED_CLASSES | |
from inference_utils.inference import interactive_infer_image | |
def overlay_masks(image, masks, colors):
    """Tint the masked regions of an image with per-mask colours.

    Masks are applied in order onto the same canvas, so a pixel covered by
    several masks is blended repeatedly (later masks over earlier results).

    Args:
        image: Source image (e.g. a PIL image); copied and converted to a
            ``uint8`` NumPy array, so the caller's object is not modified.
        masks: Iterable of arrays; a pixel is tinted where the mask is ``> 0``.
        colors: Iterable of RGB tuples, paired with ``masks`` positionally
            (``zip`` stops at the shorter of the two).

    Returns:
        A new PIL image where every masked pixel is 40% original, 60% colour.
    """
    canvas = np.array(image.copy(), dtype=np.uint8)
    for region, rgb in zip(masks, colors):
        selected = region > 0
        blended = canvas[selected] * 0.4 + np.array(rgb) * 0.6
        canvas[selected] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
def generate_colors(n):
    """Return ``n`` RGB colours from matplotlib's "tab10" colormap.

    Each colour is a 3-tuple of ints in 0-255, obtained by scaling the
    colormap's float RGB channels by 255 and truncating with ``int``.
    NOTE(review): "tab10" is a 10-entry palette; check how indices >= 10
    are mapped if more than ten prompts must stay distinguishable.
    """
    palette = plt.get_cmap("tab10")
    result = []
    for idx in range(n):
        # palette(idx) yields RGBA; the alpha channel is dropped.
        red, green, blue = palette(idx)[:3]
        result.append((int(255 * red), int(255 * green), int(255 * blue)))
    return result
def init_model():
    """Download the BiomedParse checkpoint and build an inference-ready model.

    Returns:
        The BiomedParse model loaded from ``biomedparse_v1.pt``, switched to
        eval mode and moved to CUDA, after ``get_text_embeddings`` has been
        run on ``BIOMED_CLASSES`` plus "background".
    """
    # Fetch the weights from the Hugging Face Hub; HF_TOKEN from the
    # environment (None if unset) is forwarded as the access token.
    checkpoint_path = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    config_path = "configs/biomedparse_inference.yaml"
    options = load_opt_from_config_files([config_path])
    options = init_distributed(options)

    # from_pretrained/eval/cuda are chained in the original style, so each
    # is assumed to return the model object it was called on.
    network = build_model(options)
    biomed_model = BaseModel(options, network)
    biomed_model = biomed_model.from_pretrained(checkpoint_path)
    biomed_model = biomed_model.eval().cuda()

    # Warm up the language encoder with the full class vocabulary; the return
    # value is discarded, so presumably the encoder caches the embeddings
    # internally for later inference — TODO confirm in lang_encoder.
    with torch.no_grad():
        text_encoder = biomed_model.model.sem_seg_head.predictor.lang_encoder
        text_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )
    return biomed_model
def predict(image, prompts):
    """Segment ``image`` once per comma-separated prompt and overlay the results.

    Uses the module-level ``model`` (assigned by ``run``).

    Args:
        image: Input image from the ``gr.Image(type="pil")`` component, or
            ``None`` when that input is empty.
        prompts: Comma-separated prompt string, e.g.
            "neoplastic cells, inflammatory cells". Surrounding whitespace is
            stripped from each prompt and blank entries are ignored.

    Returns:
        A PIL image with one coloured overlay per prompt, or ``None`` when no
        image or no non-blank prompt was given.
    """
    # An empty Gradio image input arrives as None; return early instead of
    # failing on ``image.mode`` below.
    if image is None or not prompts:
        return None

    # Drop blank entries so inputs like "edema," or " , " never send empty
    # strings to the model.
    prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
    if not prompt_list:
        return None

    if image.mode != "RGB":
        image = image.convert("RGB")

    pred_mask = interactive_infer_image(model, image, prompt_list)

    # One binary mask per prompt: pred_mask is indexed by prompt position.
    # Values are presumably per-pixel probabilities, thresholded at 0.5 —
    # TODO confirm against interactive_infer_image.
    masks = [pred_mask[i] > 0.5 for i in range(len(prompt_list))]
    colors = generate_colors(len(prompt_list))
    return overlay_masks(image, masks, colors)
def run():
    """Load the model and serve the BiomedParse Gradio demo.

    Stores the loaded model in the module-level ``model`` (read by
    ``predict``), builds a ``gr.Interface`` with an image + prompt input and
    an image output, and launches it with server_name "0.0.0.0" on port 7860.
    """
    global model
    model = init_model()

    # Example (image path, prompt string) pairs offered in the UI.
    covid_xray = "examples/covid_1585.png"
    lung_ct = "examples/LIDC-IDRI-0140_143_280_CT_lung.png"
    breast_pathology = "examples/Part_1_516_pathology_breast.png"
    fundus = "examples/T0011.jpg"
    example_rows = [
        ["examples/144DME_as_F.jpeg", "edema"],
        ["examples/C3_EndoCV2021_00462.jpg", "polyp"],
        [covid_xray, "left lung"],
        [covid_xray, "right lung"],
        [covid_xray, "COVID-19 infection"],
        ["examples/ISIC_0015551.jpg", "lesion"],
        [lung_ct, "lung nodule"],
        [lung_ct, "COVID-19 infection"],
        [breast_pathology, "connective tissue cells"],
        [breast_pathology, "neoplastic cells"],
        [breast_pathology, "neoplastic cells, inflammatory cells"],
        [fundus, "optic disc"],
        [fundus, "optic cup"],
        ["examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png", "glioma"],
    ]

    # Components are created in the same order as before: inputs, then output.
    image_input = gr.Image(type="pil", label="Input Image")
    prompt_input = gr.Textbox(
        label="Prompts",
        placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
    )
    prediction_output = gr.Image(type="pil", label="Prediction")

    demo = gr.Interface(
        fn=predict,
        inputs=[image_input, prompt_input],
        outputs=prediction_output,
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=example_rows,
    )
    demo.launch(server_name="0.0.0.0", server_port=7860)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()