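# NOTE(review): the next three lines look like residue copied from a code-hosting
# web page (an avatar caption and what appears to be a revision id); they are
# kept as comments so the module can still be imported.
# januarevan's picture
# .
# 0278799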
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import segmentation_models_pytorch as smp
import torch
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from io import BytesIO
import traceback
import base64
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# PAN segmentation network: ResNeXt-50 (32x4d) encoder, 3 input channels,
# 1 output channel; no ``activation`` argument is passed.
model = smp.PAN(encoder_name="resnext50_32x4d", in_channels=3, classes=1)
# NOTE: torch.load unpickles the checkpoint, so the file must be trusted.
model.to(DEVICE).load_state_dict(torch.load("./model/pan_resnext50_32x4d_adam_lr001_batch16_epoch_50.ckpt", map_location=DEVICE))
# This model is only used for inference. eval() makes BatchNorm use its stored
# running statistics instead of the statistics of whatever batch is passed in,
# and stops each forward pass from updating those statistics.
model.eval()
app = FastAPI()

# Fully open CORS policy: every origin, HTTP method and request header is accepted.
# NOTE(review): restrict allow_origins to the real front-end origin(s) for
# production. The FastAPI CORS docs state that credentials cannot be allowed
# together with wildcard origins/methods/headers; confirm whether clients
# actually send credentials before relying on allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class CustomDataset(Dataset):
    """Map-style dataset over an indexable collection of images.

    Item ``idx`` is the dict ``{'image': data[idx]}``. When ``transform`` is
    truthy, it is called with that dict and its return value is yielded instead.
    """

    def __init__(self, data, transform=None):
        # ``data`` only needs to support len() and integer indexing.
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = dict(image=self.data[idx])
        return self.transform(item) if self.transform else item
def combine_images(original_image_np, label_image_np):
    """Keep the pixels of ``original_image_np`` where the label is white; zero the rest.

    Args:
        original_image_np: Image array whose first two axes match the label's
            spatial size (e.g. ``(H, W, 3)`` or ``(H, W)``).
        label_image_np: Label/mask array of shape ``(H, W)`` or ``(H, W, C)``.
            A label with more than two dimensions is first averaged over
            axis 2 into a single channel.

    Returns:
        A new array with the shape and dtype of ``original_image_np``. It holds
        the original pixel values where the (averaged) label equals 255 and 0
        everywhere else. Neither input is modified.
    """
    if label_image_np.ndim > 2:
        # Average in floating point. np.mean(..., dtype=np.uint8) accumulates
        # in uint8 and wraps around (255 + 255 + 255 -> 253), so a white RGB
        # label would average to 84 and never match 255 below.
        label_image_np = label_image_np.mean(axis=2)
    # True where the label is pure white.
    mask = label_image_np == 255
    # Start from an all-zero image and copy the original pixels under the mask.
    combined_image_np = np.zeros_like(original_image_np)
    combined_image_np[mask] = original_image_np[mask]
    return combined_image_np
@app.get("/")
async def root():
    """Static greeting endpoint; handy as a quick liveness check."""
    payload = {"message": "Hello World"}
    return payload
def _segment_to_base64_jpeg(contents):
    """Segment raw image bytes with ``model``; return the masked image as base64 JPEG text."""
    # cv2.resize takes (width, height), so arrays come out as (544, 160, 3).
    target_size = (160, 544)

    # convert("RGB") turns grayscale, RGBA and palette uploads into the
    # 3-channel input the network is built for (in_channels=3). Without it,
    # such files fail later on a channel-count mismatch.
    image = Image.open(BytesIO(contents)).convert("RGB")
    open_cv_image = cv2.resize(np.array(image), target_size)

    # HWC uint8 -> 1xCxHxW float32 with raw 0-255 values (no normalisation,
    # matching the previous behaviour).
    batch = torch.from_numpy(open_cv_image).permute(2, 0, 1).unsqueeze(0).contiguous()
    batch = batch.to(DEVICE).float()

    # eval() makes the encoder's BatchNorm layers use their stored running
    # statistics. The prediction therefore depends only on this image, and
    # serving requests no longer rewrites those statistics. Cheap and idempotent.
    model.eval()
    with torch.no_grad():
        output = model(batch)

    # NOTE(review): ``model`` is created without an ``activation`` argument, so
    # this is presumably a raw score; the 0.5 cut-off is kept unchanged.
    # Confirm that it matches how the checkpoint was trained.
    foreground = (output[0] > 0.5).squeeze().cpu().numpy()
    label = np.where(foreground, 255, 0).astype(np.uint8)

    combined = Image.fromarray(combine_images(open_cv_image, label))
    buffered = BytesIO()
    combined.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


@app.post("/segmentation")
async def segmentation(file: UploadFile = File(...)):
    """Segment an uploaded image and return only its foreground pixels.

    The upload is decoded, converted to RGB, resized to 160x544 (width x height)
    and run through ``model``. The thresholded prediction is used to zero every
    pixel outside the mask, and the result is JPEG-encoded.

    Returns:
        200 ``{"result": "good", "image": <base64 JPEG>}`` on success.
        500 ``{"error": str(exc), "traceback": <formatted traceback>}`` if
        decoding, inference or encoding raises.
    """
    contents = await file.read()
    try:
        # NOTE(review): inference runs synchronously inside this async handler,
        # so it blocks the event loop while the model runs.
        img_str = _segment_to_base64_jpeg(contents)
    except Exception as e:
        # Top-level request boundary: report any failure as JSON.
        # NOTE(review): the full traceback is sent to the client; consider
        # logging it server-side instead before exposing this publicly.
        error_message = traceback.format_exc()
        return JSONResponse(status_code=500, content={"error": str(e), "traceback": error_message})
    return JSONResponse(status_code=200, content={"result": 'good', "image": img_str})
@app.post("/predict")
async def predict():
    """Placeholder endpoint: not implemented yet, so it always returns None."""
    result = None
    return result