# (Hugging Face Spaces status banner — scrape residue, not part of the program)
import base64
import os
import traceback
from io import BytesIO

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torch.utils.data import Dataset, DataLoader
# Run on GPU when one is available; the model and all input tensors are
# moved to this device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# PAN segmentation model with a ResNeXt-50 (32x4d) encoder; 3-channel input,
# single-channel (binary mask) output, weights restored from a local checkpoint.
model = smp.PAN(encoder_name="resnext50_32x4d", in_channels=3, classes=1)
model.to(DEVICE).load_state_dict(torch.load("./model/pan_resnext50_32x4d_adam_lr001_batch16_epoch_50.ckpt", map_location=DEVICE))
# BUG FIX: this service only ever runs inference, but the model was left in
# training mode — BatchNorm kept updating running stats and dropout stayed
# active, making predictions depend on request batches. Switch to eval mode.
model.eval()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with the list of allowed origins for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class CustomDataset(Dataset):
    """Minimal in-memory Dataset over a sequence of images.

    Each sample is a dict ``{'image': <array>}``. If a ``transform``
    callable is supplied it receives that dict and its return value is
    used as the sample.
    """

    def __init__(self, data, transform=None):
        self.data = data            # indexable collection of images
        self.transform = transform  # optional per-sample callable

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = {'image': self.data[idx]}
        return self.transform(sample) if self.transform else sample
def combine_images(original_image_np, label_image_np):
    """Mask the original image with a binary label image.

    Pixels where the label is white (255) keep their original values;
    all other pixels are zeroed.

    Parameters
    ----------
    original_image_np : np.ndarray
        Source image, (H, W) or (H, W, C).
    label_image_np : np.ndarray
        Mask image; single-channel, or multi-channel (reduced to
        grayscale via a per-pixel channel mean).

    Returns
    -------
    np.ndarray
        Array of the same shape/dtype as ``original_image_np`` with
        non-mask pixels set to 0.
    """
    if label_image_np.ndim > 2:
        # BUG FIX: np.mean(..., dtype=np.uint8) accumulated the channel sum
        # in uint8, which overflows (255+255+255 wraps mod 256), so pure-white
        # pixels did not reduce to 255 and the mask came out wrong. Average
        # in the default float dtype, then cast the result.
        label_image_np = np.mean(label_image_np, axis=2).astype(np.uint8)
    # Boolean mask of exactly-white pixels.
    mask = label_image_np == 255
    combined_image_np = np.zeros_like(original_image_np)
    # Keep original pixels only where the mask is white.
    combined_image_np[mask] = original_image_np[mask]
    return combined_image_np
async def root():
    """Health-check handler returning a static greeting.

    NOTE(review): its route decorator (e.g. ``@app.get("/")``) appears to
    have been lost from this chunk — confirm against the original file.
    """
    greeting = {"message": "Hello World"}
    return greeting
async def segmentation(file: UploadFile = File(...)):
    """Segment an uploaded image and return the masked result as base64 JPEG.

    NOTE(review): the route decorator (e.g. ``@app.post(...)``) appears to
    have been lost from this chunk — confirm against the original file.

    The upload is resized to 160x544 and batched together with every image
    found in ``./images`` (presumably as batch padding/context for the
    model — TODO confirm why these extra images are required).

    Returns
    -------
    JSONResponse
        200 with ``{"result": "good", "image": <base64 JPEG>}`` on success;
        500 with ``{"error": ..., "traceback": ...}`` on any failure.
    """
    contents = await file.read()

    # Build the input batch: reference images from ./images, upload first.
    # BUG FIX: the loop variable used to be named ``file``, shadowing the
    # UploadFile parameter for the rest of the function.
    image_dataset = []
    for fname in os.listdir("./images"):
        image_dataset.append(cv2.resize(cv2.imread('./images/' + fname), (160, 544)))

    image = Image.open(BytesIO(contents))
    open_cv_image = np.array(image)
    # NOTE(review): PIL decodes to RGB while the cv2.imread images above are
    # BGR — the batch mixes channel orders; verify the model's expectation.
    open_cv_image = cv2.resize(open_cv_image, (160, 544))
    image_dataset.insert(0, open_cv_image)

    # (N, H, W, C) -> (N, C, H, W) for the model.
    image_dataset = np.transpose(image_dataset, (0, 3, 1, 2))
    dataset = CustomDataset(image_dataset)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
    try:
        with torch.no_grad():
            for batch in dataloader:
                temp_image = batch['image'].to(DEVICE).float()
                output = model(temp_image)
                # NOTE(review): thresholds the raw model output at 0.5 — if
                # the model emits logits a sigmoid is missing here; confirm.
                mask = (output[0] > 0.5).squeeze().cpu().numpy()
                mask = (mask * 255).astype(np.uint8)
                combined_image_np = combine_images(open_cv_image, mask)
                combined_image = Image.fromarray(combined_image_np)
                buffered = BytesIO()
                combined_image.save(buffered, format="JPEG")
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                # BUG FIX: the uploaded image is element 0 of the FIRST batch,
                # but the loop previously ran over every batch and overwrote
                # img_str with other images' masks (and img_str was unbound if
                # the loader was empty). Stop after the batch with the upload.
                break
    except Exception as e:
        error_message = traceback.format_exc()
        return JSONResponse(status_code=500, content={"error": str(e), "traceback": error_message})
    else:
        return JSONResponse(status_code=200, content={"result": 'good', "image": img_str})
async def predict():
    """Placeholder handler; not yet implemented.

    NOTE(review): its route decorator appears to be missing from this chunk.
    """
    return None