# Hugging Face file-view header captured with the source (not program text):
#   ansal's picture
#   Update app.py
#   724d065 verified
#   raw
#   history blame
#   3.21 kB
import os
import tempfile
from pathlib import Path
from PIL import Image
import torch
import numpy as np
import torchvision.transforms as transforms
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
import base64
from io import BytesIO
from transformers import SamModel
# Preprocessing applied to every uploaded image before it is fed to the model:
# resize to a fixed 1024x1024 and convert the PIL image to a tensor.
image_resize_transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ]
)
# Page layout: a single-file image upload followed by two image outputs.
# The output ids must match the render functions defined in ``server``.
app_ui = ui.page_fluid(
    # Upload control; its value is read in the server as ``input.file2()``.
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    # The resized upload as-is.
    ui.output_image("original_image"),
    # The upload with the predicted mask blended on top.
    ui.output_image("image_display"),
)
# Process-wide cache so the SAM checkpoint is read from disk once instead of
# on every upload.
_SEGMENTATION_MODEL_CACHE = {}


def _get_segmentation_model():
    """Return ``(model, device)``, loading the fine-tuned SAM weights on first use."""
    if "entry" not in _SEGMENTATION_MODEL_CACHE:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        model.load_state_dict(torch.load('model.pth', map_location=device))
        model.to(device)
        model.eval()
        _SEGMENTATION_MODEL_CACHE["entry"] = (model, device)
    return _SEGMENTATION_MODEL_CACHE["entry"]


def _save_temp_png(image):
    """Write *image* to a new temporary ``.png`` file and return its path."""
    # Only reserve the name here and close the handle right away: keeping it
    # open leaks a file descriptor, and an open NamedTemporaryFile cannot be
    # reopened by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
        path = handle.name
    image.save(path, "PNG")
    return path


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server function: segment the uploaded image and show the results.

    Registers two image outputs:

    * ``original_image`` - the upload resized to 1024x1024.
    * ``image_display`` - the same image with the predicted mask blended on
      top at 50% opacity.

    Both outputs show nothing until a file has been uploaded.
    """
    # Paths of every temporary PNG written for this session, removed when the
    # session ends so repeated uploads do not fill the temp directory.
    temp_paths: list[str] = []

    def _remove_temp_files():
        while temp_paths:
            try:
                os.remove(temp_paths.pop())
            except OSError:
                # Best-effort cleanup: the file may already be gone.
                pass

    session.on_ended(_remove_temp_files)

    @reactive.calc
    def loaded_image():
        """Run segmentation on the upload.

        Returns ``(original_path, overlay_path)`` pointing at temporary PNG
        files, or ``None`` when nothing has been uploaded yet.
        """
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None

        model, device = _get_segmentation_model()

        # Decode the upload once and release the file handle immediately.
        with Image.open(file[0]["datapath"]) as uploaded:
            image = uploaded.convert('RGB')

        image_tensor = image_resize_transform(image).to(device)
        with torch.no_grad():
            outputs = model(
                pixel_values=image_tensor.unsqueeze(0),
                multimask_output=False,
            )

        # ``pred_masks`` contains raw logits. Map them to [0, 1] before
        # scaling to 8-bit; casting unbounded logits straight to uint8 wraps
        # around and produces a meaningless mask.
        mask_probs = torch.sigmoid(outputs.pred_masks.squeeze(1))[:, 0, :, :]
        mask_array = mask_probs.detach().cpu().squeeze().numpy()
        mask_array = (mask_array * 255).astype(np.uint8)

        mask = Image.fromarray(mask_array)
        mask = mask.resize((1024, 1024), Image.LANCZOS).convert('RGBA')
        # Uniform 50% opacity so the photo stays visible under the mask.
        mask.putalpha(Image.new('L', mask.size, 128))

        resized = image.resize((1024, 1024), Image.LANCZOS).convert('RGBA')
        combined = Image.alpha_composite(resized, mask)

        original_path = _save_temp_png(resized)
        temp_paths.append(original_path)
        combined_path = _save_temp_png(combined)
        temp_paths.append(combined_path)
        return original_path, combined_path

    @render.image
    def original_image():
        """Serve the resized upload."""
        result = loaded_image()
        if result is None:
            return None
        original_path, _ = result
        return {"src": original_path, "width": "300px"}

    @render.image
    def image_display():
        """Serve the mask overlay."""
        result = loaded_image()
        if result is None:
            return None
        _, overlay_path = result
        # ``render.image`` reads ``src`` as a file path and encodes the file
        # itself, so hand it the path rather than a ``data:`` URL.
        return {"src": overlay_path, "width": "300px"}
# Shiny application object combining the UI definition with the server logic.
app = App(app_ui, server)