File size: 1,778 Bytes
2807856
 
 
 
2474580
2807856
 
2474580
650149e
2807856
 
 
 
 
 
2474580
2807856
2474580
 
650149e
2474580
2807856
 
 
 
 
 
 
2474580
 
2807856
2474580
 
650149e
2474580
 
650149e
2474580
 
 
 
 
1d67d5a
 
 
 
 
 
 
 
2474580
1d67d5a
2474580
 
1d67d5a
2474580
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import base64
import io
import json
from typing import Optional, List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor

class SegmentRequest(BaseModel):
    """Request body accepted by the ``/segment`` endpoint."""
    # Base64-encoded bytes of the image file to segment.
    file_b64: str
    # Optional prompt points, one inner list per point (presumably [x, y]
    # pixel coordinates -- confirm against the SAM processor). When omitted
    # or empty, no point prompt is forwarded to the processor.
    input_points: Optional[List[List[int]]] = None

# FastAPI application object; the route below is registered on it.
app = FastAPI(title="SAM MedTesting")

# Checkpoint identifier passed to both from_pretrained() calls below.
MODEL_ID = "facebook/sam-vit-base"
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device   = "cuda" if torch.cuda.is_available() else "cpu"
# Processor and model are created once at import time and reused by every
# request; the model is moved to the selected device.
processor = SamProcessor.from_pretrained(MODEL_ID)
model     = SamModel.from_pretrained(MODEL_ID).to(device)

@app.post("/segment")
async def segment(req: SegmentRequest):
    """Segment a base64-encoded image with SAM and return one mask as a PNG.

    Args:
        req: Request body. ``file_b64`` holds the base64-encoded image file;
            ``input_points`` optionally holds prompt points, each a
            two-element list.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        first predicted mask for the image, scaled by 255 and stored as
        8-bit pixels.

    Raises:
        HTTPException: 400 when ``file_b64`` is not valid base64, when the
            decoded bytes cannot be opened as an image, or when a prompt
            point does not have exactly two coordinates.
    """
    # NOTE(review): the inference below is synchronous work inside an
    # ``async def`` handler, so it presumably blocks the event loop while it
    # runs -- consider offloading it if concurrent requests matter.

    # Decode the image; malformed client input is a 400, not a server error.
    try:
        img_bytes = base64.b64decode(req.file_b64)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail="file_b64 is not valid base64"
        ) from err
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="file_b64 does not contain a readable image"
        ) from err

    # Validate prompt points up front so a bad point fails with a clear
    # message instead of an error deep inside the processor.
    pts = req.input_points
    if pts and any(len(point) != 2 for point in pts):
        raise HTTPException(
            status_code=400,
            detail="each entry of input_points must be an [x, y] pair",
        )

    # The processor takes one list of points per image, hence the extra
    # nesting; an empty or missing list means "no point prompt".
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted masks back to the original image size (on CPU).
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # First image, first entry: a 3-D tensor whose leading axis is presumably
    # SAM's candidate masks for the prompt -- confirm against the
    # transformers docs. Index 0 is returned unconditionally.
    # NOTE(review): ``outputs.iou_scores`` could be used to pick the
    # best-scoring candidate instead of always taking the first.
    candidates = masks[0][0]
    mask_np = (candidates[0, :, :] * 255).to(torch.uint8).numpy()

    # Encode the 2-D uint8 mask as PNG into an in-memory buffer.
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/png")