File size: 3,028 Bytes
db0960e
061386e
 
 
 
 
744e6f4
 
061386e
ea20fb6
061386e
744e6f4
061386e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744e6f4
061386e
 
ea20fb6
78f86c4
 
 
 
ea20fb6
78f86c4
 
744e6f4
78f86c4
061386e
 
 
 
 
ea20fb6
744e6f4
 
 
 
061386e
ea20fb6
 
1a2c622
061386e
744e6f4
061386e
 
744e6f4
 
 
 
 
061386e
ea20fb6
744e6f4
 
061386e
ea20fb6
 
744e6f4
 
 
 
db0960e
744e6f4
061386e
 
 
 
ea20fb6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import streamlit as st
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image, ImageOps

# ========== Model Definition ==========
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation net: timm encoder plus a small upsampling decoder.

    Only the deepest encoder feature map is decoded. It passes through three
    3x3-conv / 2x-bilinear-upsample stages, then a 1x1 conv down to one
    channel and a sigmoid. ``forward`` finally resizes the result to the
    input's spatial size.

    The attribute names (``backbone``, ``decoder``) and the decoder's layer
    order determine the state_dict keys, so they must stay stable for saved
    checkpoints to load.

    Args:
        encoder_name: timm model name, created with ``features_only=True``.
        pretrained: forwarded to ``timm.create_model`` to request pretrained
            backbone weights.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name,
            features_only=True,
            pretrained=pretrained,
        )
        # Channel count of each feature map the backbone returns.
        self.encoder_channels = self.backbone.feature_info.channels()

        # (in, out) channel pairs for the three conv + 2x upsample stages.
        widths = (self.encoder_channels[-1], 128, 64, 32)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            stages.append(nn.Upsample(scale_factor=2, mode='bilinear'))
        # NOTE(review): there is no non-linearity between the conv/upsample
        # stages, so everything before the sigmoid is affine. Adding one would
        # change the behaviour of already-trained checkpoints.
        stages.append(nn.Conv2d(32, 1, kernel_size=1))
        stages.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Return per-pixel probabilities at the input's height and width.

        Args:
            x: image batch, assumed (N, C, H, W) as expected by the timm
                backbone (TODO confirm channel count against the encoder).

        Returns:
            Tensor of shape (N, 1, H, W) with values in [0, 1].
        """
        target_hw = (x.shape[2], x.shape[3])
        deepest = self.backbone(x)[-1]
        coarse = self.decoder(deepest)
        # The decoder upsamples by a fixed 8x; stretch its output to exactly
        # the input resolution.
        return nn.functional.interpolate(
            coarse, size=target_hw, mode='bilinear', align_corners=False
        )

# ========== Load Model ==========
@st.cache_resource
def load_model():
    """Download the fine-tuned checkpoint and return the model in eval mode.

    Wrapped in ``st.cache_resource`` so the download and model construction
    are shared across Streamlit reruns instead of repeating on every rerun.

    Returns:
        MobileViTSegmentation: model with the checkpoint weights loaded,
        placed on CPU and switched to evaluation mode.
    """
    cache_dir = "/tmp/huggingface"  # Safe writable directory in HF Spaces

    checkpoint_path = hf_hub_download(
        repo_id="svsaurav95/ToothSegmentation",
        filename="mobilevit_teeth_segmentation.pth",
        cache_dir=cache_dir,
    )

    # The checkpoint comes from a remote repository and torch.load unpickles
    # it. weights_only=True restricts unpickling to tensors and primitive
    # containers, so a tampered file cannot execute arbitrary code. The plain
    # state_dict that load_state_dict expects loads fine under this restriction.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # pretrained=False: the strict load_state_dict below overwrites every
    # parameter, including the backbone's, so timm's pretrained download
    # would be wasted.
    model = MobileViTSegmentation(pretrained=False)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# ========== Page Setup ==========
# set_page_config runs before anything else. Older Streamlit releases raise
# unless it is the first Streamlit command in the script, and the first
# (uncached) load_model() call below may render a cache spinner.
st.set_page_config(page_title="Tooth Segmentation", layout="wide")

model = load_model()

# ========== Image Preprocessing ==========
# Fixed 256x256 resize (aspect ratio is not preserved), then conversion to a
# float CHW tensor in [0, 1]. No mean/std normalization is applied.
transform = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor()
])

# Probability above which a pixel is classified as tooth.
MASK_THRESHOLD = 0.7
# Tint for tooth pixels: blue, alpha-blended at 100/255 over the photo.
OVERLAY_RGBA = (0, 0, 255, 100)

# ========== UI ==========
st.title("Tooth Segmentation using MobileViT")

uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # The uploader filters by file extension only, so the content may still
    # be undecodable. UnidentifiedImageError subclasses OSError, and truncated
    # data also surfaces as OSError when convert() loads the pixels.
    try:
        # exif_transpose applies the EXIF orientation tag, so phone photos are
        # segmented and displayed upright.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    input_tensor = transform(image).unsqueeze(0)  # (1, 3, 256, 256)

    with torch.no_grad():
        prob_map = model(input_tensor)[0, 0].numpy()

    # Threshold to a 0/255 single-channel mask and resize it to the original size.
    binary = (prob_map > MASK_THRESHOLD).astype(np.uint8) * 255
    mask = Image.fromarray(binary).resize(image.size)

    # Tint the whole photo, then keep the tinted pixels only where the mask is
    # set. Compositing the bare overlay directly would copy its alpha (100)
    # into the result, replacing tooth pixels with semi-transparent blue
    # rather than tinting them.
    base = image.convert("RGBA")
    tinted = Image.alpha_composite(base, Image.new("RGBA", image.size, OVERLAY_RGBA))
    final = Image.composite(tinted, base, mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Area Segmentation", use_container_width=True)