# SAM_MedTesting / app.py
# FastAPI service exposing a Segment Anything Model (SAM) segmentation endpoint.
# (Origin: uploaded by Axzyl, commit 1d67d5a, verified.)
import base64
import io
import json
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
class SegmentRequest(BaseModel):
file_b64: str
input_points: Optional[List[List[int]]] = None
app = FastAPI(title="SAM MedTesting")
MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)
@app.post("/segment")
async def segment(req: SegmentRequest):
    """Segment an image with SAM and return the best mask as a PNG.

    Request: base64-encoded image plus optional point prompts ([[x, y], ...]).
    Response: single-channel PNG where foreground is 255 and background is 0.
    Raises HTTP 400 when the payload is not valid base64 or not a decodable
    image (previously these surfaced as opaque 500s).
    """
    # Decode the uploaded image; reject malformed payloads explicitly.
    try:
        img_bytes = base64.b64decode(req.file_b64)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {exc}") from exc

    # Prepare model inputs. SAM expects points nested one level deeper
    # (a per-image batch dimension), hence the extra list around `pts`.
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-res predicted masks back to the original image size.
    # Returns one tensor per input image, shaped (num_prompts, num_masks, H, W).
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # SAM emits several candidate masks per prompt. The original code took the
    # first candidate arbitrarily (and its comments misstated the tensor rank);
    # instead pick the candidate the model itself scores highest.
    candidate_masks = masks[0][0]                       # (num_masks, H, W), bool
    best_idx = int(outputs.iou_scores[0, 0].argmax())   # model-predicted mask quality
    best_mask = candidate_masks[best_idx]               # (H, W)

    # bool -> uint8 {0, 255} for a standard grayscale PNG.
    mask_np = (best_mask.to(torch.uint8) * 255).cpu().numpy()

    # Encode as PNG and stream it back.
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")