Spaces:
Sleeping
Sleeping
File size: 3,028 Bytes
db0960e 061386e 744e6f4 061386e ea20fb6 061386e 744e6f4 061386e 744e6f4 061386e ea20fb6 78f86c4 ea20fb6 78f86c4 744e6f4 78f86c4 061386e ea20fb6 744e6f4 061386e ea20fb6 1a2c622 061386e 744e6f4 061386e 744e6f4 061386e ea20fb6 744e6f4 061386e ea20fb6 744e6f4 db0960e 744e6f4 061386e ea20fb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import numpy as np
import streamlit as st
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps
# ========== Model Definition ==========
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation net: timm encoder plus a small conv decoder.

    The last feature map returned by the ``features_only`` backbone is passed
    through three (3x3 conv -> 2x bilinear upsample) stages with 128, 64 and
    32 output channels, then a 1x1 conv to one channel and a sigmoid. The
    decoder output is finally resized to the input's spatial size, so for an
    (N, C, H, W) input the result is an (N, 1, H, W) map of values in [0, 1].

    Args:
        encoder_name: timm model name used as the feature-extracting backbone.
        pretrained: passed to timm; whether to fetch pretrained backbone weights.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name, features_only=True, pretrained=pretrained
        )
        self.encoder_channels = self.backbone.feature_info.channels()

        # Assemble the decoder stage by stage. The resulting layer order
        # (weights at Sequential indices 0, 2, 4, 6) must not change, because
        # saved checkpoints address the weights by those indices.
        stages = []
        channels_in = self.encoder_channels[-1]
        for channels_out in (128, 64, 32):
            stages.append(nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1))
            stages.append(nn.Upsample(scale_factor=2, mode='bilinear'))
            channels_in = channels_out
        stages.append(nn.Conv2d(channels_in, 1, kernel_size=1))
        stages.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return the decoded map for *x*, resized to x's height and width."""
        height, width = x.shape[2], x.shape[3]
        deepest_features = self.backbone(x)[-1]
        coarse_map = self.decoder(deepest_features)
        # The decoder upsamples only 8x in total, so the map is brought back
        # to the exact input resolution here.
        return nn.functional.interpolate(
            coarse_map, size=(height, width), mode='bilinear', align_corners=False
        )
# ========== Load Model ==========
@st.cache_resource
def load_model():
    """Download the fine-tuned checkpoint and return the model ready for inference.

    Wrapped in ``st.cache_resource`` so the download and weight loading are
    done once and reused across Streamlit reruns instead of on every rerun.

    Returns:
        MobileViTSegmentation: model with weights loaded from the
        ``svsaurav95/ToothSegmentation`` Hub repo, on CPU, in eval mode.
    """
    cache_dir = "/tmp/huggingface"  # Safe writable directory in HF Spaces
    checkpoint_path = hf_hub_download(
        repo_id="svsaurav95/ToothSegmentation",
        filename="mobilevit_teeth_segmentation.pth",
        cache_dir=cache_dir,
    )
    # pretrained=False: every weight is overwritten by the checkpoint below,
    # so fetching the backbone's pretrained weights would be wasted work.
    model = MobileViTSegmentation(pretrained=False)
    # torch.load unpickles the file, and this file comes from a remote repo.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is sufficient for the state_dict that load_state_dict expects.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout / batch-norm
    return model


model = load_model()
# ========== Image Preprocessing ==========
# Every upload is resized to exactly 256x256 (the aspect ratio is not
# preserved), then ToTensor turns the PIL image into a float (C, H, W)
# tensor scaled to [0, 1]. No mean/std normalisation is applied.
transform = T.Compose(
    [
        T.Resize(size=(256, 256)),
        T.ToTensor(),
    ]
)
# ========== UI ==========
# Probability above which a pixel is treated as tooth.
MASK_THRESHOLD = 0.7


def _tint_masked_region(image, mask, color=(0, 0, 255), alpha=100):
    """Return *image* with a translucent *color* tint wherever *mask* is set.

    Args:
        image: RGB PIL image.
        mask: "L"-mode PIL image the same size as *image*. 255 selects the
            tinted pixel, 0 keeps the original pixel, and values in between
            mix the two (giving a soft edge).
        color: RGB tint colour.
        alpha: tint opacity, 0-255.

    Returns:
        An opaque RGBA PIL image.
    """
    base = image.convert("RGBA")
    tint = Image.new("RGBA", image.size, (*color, alpha))
    # alpha_composite blends the translucent tint *over* the photo.
    # Image.composite(tint, base, mask) alone would instead copy the tint's
    # pixels (alpha 100) into the result, so the teeth would be replaced by
    # see-through blue rather than tinted.
    tinted = Image.alpha_composite(base, tint)
    return Image.composite(tinted, base, mask)


st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("Tooth Segmentation using MobileViT")
uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        # exif_transpose applies the camera orientation tag, so phone photos
        # are neither segmented nor displayed sideways.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode, and OSError for truncated ones.
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        pred_mask = model(input_tensor)[0, 0].numpy()

    # Threshold, then resize the model-resolution mask to the upload's size.
    # Resampling leaves intermediate values along mask edges, which the
    # composite step turns into a soft boundary.
    pred_mask = (pred_mask > MASK_THRESHOLD).astype(np.uint8) * 255
    pred_mask = Image.fromarray(pred_mask).resize(image.size)

    final = _tint_masked_region(image, pred_mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Area Segmentation", use_container_width=True)
|