# app.py by BioMike, last commit 315f9a8 ("Update app.py"), 1.84 kB.
# (Header reconstructed from scraped page text; the raw/history/blame links were page UI.)
import gradio as gr
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from src.model import ClipSegMultiClassModel
from src.config import ClipSegMultiClassConfig
# === Load model ===
# Human-readable class names; the list index is the integer class id that the
# legend in segment_with_legend pairs with a colour.
class_labels = ["background", "Pig", "Horse", "Sheep"]

# Class id -> RGB colour (0-255) used to paint that class in the visualised mask.
label2color = dict(
    enumerate(
        [
            [0, 0, 0],      # 0: background -> black
            [255, 0, 0],    # 1: Pig -> red
            [0, 255, 0],    # 2: Horse -> green
            [0, 0, 255],    # 3: Sheep -> blue
        ]
    )
)

# NOTE(review): this config is constructed but not passed to from_pretrained
# below (the checkpoint is loaded on its own); confirm whether it is still needed.
config = ClipSegMultiClassConfig(
    model="CIDAS/clipseg-rd64-refined",
    class_labels=class_labels,
    label2color=label2color,
)

# Load the fine-tuned checkpoint and switch it to evaluation mode before serving.
model = ClipSegMultiClassModel.from_pretrained("BioMike/clipsegmulticlass_v1")
model.eval()
def colorize_mask(mask_tensor, label2color):
    """Paint a per-pixel class-id mask with one RGB colour per class.

    Args:
        mask_tensor: A ``torch.Tensor`` or any array-like holding a single 2-D
            mask of class ids, optionally wrapped in extra size-1 axes
            (e.g. ``(1, H, W)`` or ``(1, 1, H, W)``).
        label2color: Mapping from class id to an RGB triple (0-255).

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)``. Pixels whose id is not a key
        of ``label2color`` stay black.

    Raises:
        ValueError: If the input cannot be reduced to a single 2-D mask.
    """
    if isinstance(mask_tensor, torch.Tensor):
        # detach() so tensors that track gradients can still be converted.
        mask_tensor = mask_tensor.detach().cpu().numpy()
    raw = np.asarray(mask_tensor)

    mask = raw.squeeze()
    if mask.ndim != 2:
        # squeeze() also drops a genuine size-1 height or width (a 1xW or Hx1
        # mask). Recover it from the trailing two axes when every leading axis
        # is a singleton; anything else is not a single 2-D mask.
        if raw.ndim >= 2 and raw.size == raw.shape[-2] * raw.shape[-1]:
            mask = raw.reshape(raw.shape[-2:])
        else:
            raise ValueError(f"expected a single 2-D mask, got shape {raw.shape}")

    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, color in label2color.items():
        color_mask[mask == class_id] = color
    return color_mask
def segment_with_legend(input_img):
    """Run the segmentation model on one image and plot the result.

    Args:
        input_img: A file path (``str``), an ``np.ndarray`` image, or a
            ``PIL.Image.Image``. ``None`` (nothing submitted in the UI) is
            accepted and produces no plot.

    Returns:
        A ``matplotlib.figure.Figure`` showing the input blended 50/50 with the
        colourised predicted mask, plus a legend mapping each entry of
        ``class_labels`` to its ``label2color`` colour. Returns ``None`` when
        ``input_img`` is ``None``.
    """
    # Use the object-oriented API instead of pyplot. pyplot keeps every figure
    # it creates alive until plt.close(), so calling plt.subplots() per request
    # in a long-running server leaks figures.
    from matplotlib.figure import Figure
    from matplotlib.lines import Line2D

    if input_img is None:
        return None

    if isinstance(input_img, str):
        # The context manager closes the file handle Image.open keeps for lazy loading.
        with Image.open(input_img) as opened:
            input_img = opened.convert("RGB")
    elif isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img).convert("RGB")
    elif isinstance(input_img, Image.Image):
        # Image.blend below requires matching modes, and the colour mask is RGB.
        input_img = input_img.convert("RGB")

    pred_mask = model.predict(input_img)
    color_mask = colorize_mask(pred_mask, label2color)

    # Image.blend also needs equal sizes, so the input is resized to the mask.
    mask_h, mask_w = color_mask.shape[:2]
    overlay = Image.blend(
        input_img.resize((mask_w, mask_h)),
        Image.fromarray(color_mask),
        alpha=0.5,
    )

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(overlay)
    ax.axis("off")

    # Look colours up by class id instead of zipping with label2color.values(),
    # which silently depends on the dict's insertion order.
    legend_handles = [
        Line2D(
            [0], [0],
            marker="o",
            color="w",
            label=label,
            markerfacecolor=np.asarray(label2color[class_id]) / 255.0,
            markersize=10,
        )
        for class_id, label in enumerate(class_labels)
    ]
    ax.legend(handles=legend_handles, loc="lower right", framealpha=0.8)
    return fig
# Gradio UI: one image in, one matplotlib figure (overlay + legend) out.
# `demo` was previously referenced below without ever being defined, so running
# the script raised NameError. It is built at module level so code importing
# this module can reach it as well.
demo = gr.Interface(
    fn=segment_with_legend,
    inputs=gr.Image(type="pil", label="Input image"),
    outputs=gr.Plot(label="Segmentation"),
)

if __name__ == "__main__":
    demo.launch()