"""Streamlit app: segment teeth in an uploaded mouth photo with a MobileViT model.

The trained checkpoint is downloaded from the Hugging Face Hub and the loaded
model is kept by ``st.cache_resource`` so reruns do not reload it.
"""

import numpy as np
import streamlit as st
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps

# st.set_page_config must be the first Streamlit command: calling the cached
# load_model() earlier renders its spinner, after which set_page_config
# raises on Streamlit versions that enforce the "first command" rule.
st.set_page_config(page_title="Tooth Segmentation", layout="wide")

# Predicted probability above which a pixel is classified as tooth.
MASK_THRESHOLD = 0.7
# Colour and opacity (0-255) of the tint drawn over detected teeth.
OVERLAY_RGB = (0, 0, 255)
OVERLAY_ALPHA = 100


# ========== Model Definition ==========
class MobileViTSegmentation(nn.Module):
    """Binary segmentation net: timm MobileViT encoder + small conv decoder.

    Only the deepest encoder feature map is decoded. The decoder upsamples it
    8x (three 2x steps) and the result is bilinearly resized to the input's
    spatial size, giving a (batch, 1, H, W) map of probabilities in [0, 1].

    NOTE: the layer order of ``self.decoder`` defines the state-dict keys of
    the published checkpoint; do not reorder or insert layers.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
        self.encoder_channels = self.backbone.feature_info.channels()
        self.decoder = nn.Sequential(
            nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        feats = self.backbone(x)
        out = self.decoder(feats[-1])
        # The decoder's 8x upsampling need not match the input size exactly,
        # so resize to the input's H x W explicitly.
        out = nn.functional.interpolate(out, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
        return out


# ========== Load Model ==========
@st.cache_resource
def load_model():
    """Download the trained weights and return the model in eval mode on CPU."""
    cache_dir = "/tmp/huggingface"  # Safe writable directory in HF Spaces
    checkpoint_path = hf_hub_download(
        repo_id="svsaurav95/ToothSegmentation",
        filename="mobilevit_teeth_segmentation.pth",
        cache_dir=cache_dir,
    )
    model = MobileViTSegmentation(pretrained=False)
    # The checkpoint comes from the network; weights_only=True refuses
    # arbitrary pickled objects and accepts only tensors/plain containers,
    # which is all a state dict needs.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()

# ========== Image Preprocessing ==========
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])


def _predict_mask(net, image):
    """Run ``net`` on an RGB PIL image and return an "L" mask at the image's size.

    Pixels with probability above MASK_THRESHOLD are 255, others 0; resizing
    back to the original size may produce intermediate values at edges.
    """
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        probs = net(input_tensor)[0, 0].numpy()
    mask = (probs > MASK_THRESHOLD).astype(np.uint8) * 255
    return Image.fromarray(mask).resize(image.size)


def _overlay_mask(image, mask):
    """Return ``image`` (RGB) with a translucent OVERLAY_RGB tint where ``mask`` is set."""
    # Map mask values 0-255 to tint opacity 0-OVERLAY_ALPHA, so soft edge
    # pixels from the resize are tinted proportionally.
    alpha = mask.point(lambda v: v * OVERLAY_ALPHA // 255)
    tint = Image.new("RGBA", image.size, OVERLAY_RGB + (0,))
    tint.putalpha(alpha)
    # alpha_composite blends the tint *over* the photo. Image.composite would
    # instead copy the translucent blue pixel itself, discarding the teeth.
    return Image.alpha_composite(image.convert("RGBA"), tint).convert("RGB")


# ========== UI ==========
st.title("Tooth Segmentation using MobileViT")
uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Phone photos often store their rotation in EXIF rather than the
        # pixel data; apply it so the model and the display see it upright.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError:
        # Covers PIL.UnidentifiedImageError (an OSError subclass) and
        # truncated/corrupt files.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    pred_mask = _predict_mask(model, image)
    final = _overlay_mask(image, pred_mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Area Segmentation", use_container_width=True)