from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List
import base64, io

from PIL import Image
import torch
from transformers import SamModel, SamProcessor


class SegmentRequest(BaseModel):
    file_b64: str                                   # base64-encoded image bytes
    input_points: Optional[List[List[int]]] = None  # optional [x, y] click prompts


app = FastAPI(title="SAM MedTesting")

MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)


@app.post("/segment")
async def segment(req: SegmentRequest):
    # Decode and validate the incoming image.
    try:
        img_bytes = base64.b64decode(req.file_b64)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="file_b64 is not a valid base64-encoded image")

    # Prepare model inputs; the processor expects points nested per image,
    # i.e. shape (batch, num_points, 2).
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0] has shape (num_prompts, num_masks_per_prompt, H, W); keep the
    # candidate mask with the highest predicted IoU for the first prompt.
    best = outputs.iou_scores[0, 0].argmax().item()
    single_mask = masks[0][0, best]  # boolean tensor of shape (H, W)
    mask_np = (single_mask.to(torch.uint8) * 255).numpy()

    # Stream the binary mask back as a PNG.
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
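
# --- Example client (a minimal sketch; the URL, port, image path, and click
# coordinates below are illustrative assumptions, not part of the service) ---
#
#   import base64, requests
#
#   with open("scan.png", "rb") as f:
#       payload = {
#           "file_b64": base64.b64encode(f.read()).decode("ascii"),
#           "input_points": [[256, 256]],  # one (x, y) foreground click
#       }
#   resp = requests.post("http://localhost:8000/segment", json=payload)
#   resp.raise_for_status()
#   with open("mask.png", "wb") as out:
#       out.write(resp.content)            # PNG mask returned by the endpoint
#
# Start the server first, e.g.:  uvicorn app:app --port 8000
# (assuming this file is saved as app.py)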