# Hugging Face Spaces status when this file was captured: Runtime error
import functools
import json
import os
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import SamModel
# Preprocessing applied to uploaded PIL images before they reach the SAM model:
# stretch to exactly 1024x1024 (aspect ratio is not preserved) and convert to a
# float CHW tensor with values scaled into [0, 1].
_RESIZE_STEPS = [
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
]
image_resize_transform = transforms.Compose(_RESIZE_STEPS)
# Page layout: a single-file image upload control ("file2") followed by two
# image outputs ("original_image" and "image_display") filled in by the server.
_upload_control = ui.input_file(
    "file2",
    "Choose Image",
    accept=".jpg, .jpeg, .png, .tiff, .tif",
    multiple=False,
)
app_ui = ui.page_fluid(
    _upload_control,
    ui.output_image("original_image"),
    ui.output_image("image_display"),
)
@functools.lru_cache(maxsize=1)
def _load_sam_model(weights_path="model.pth"):
    """Load the fine-tuned SAM model once per process and reuse it afterwards.

    The ``facebook/sam-vit-base`` architecture is instantiated and its weights
    are replaced by the state dict stored in *weights_path*. Because the result
    is cached, changes to the weights file on disk are not picked up until the
    process restarts. A failed load is not cached, so it is retried next call.

    Returns:
        ``(model, device)``: the model in eval mode, already moved to
        ``device`` (CUDA when available, otherwise CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    # NOTE(review): torch.load unpickles the file; only point it at trusted weights.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model, device


def _predict_mask(image):
    """Segment *image* (an RGB PIL image) and return a 1024x1024 'L'-mode mask.

    Each pixel holds the predicted foreground probability scaled to 0-255.
    """
    model, device = _load_sam_model()
    # NOTE(review): only resize + ToTensor is applied (no mean/std normalisation,
    # unlike SamProcessor); confirm this matches how model.pth was trained.
    pixel_values = image_resize_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, multimask_output=False)
    # pred_masks are logits, not probabilities: multiplying them by 255 and
    # casting to uint8 overflows/wraps. Map them into [0, 1] with a sigmoid.
    # Assumes pred_masks is (batch=1, point_batch=1, masks=1, H, W) -- TODO confirm.
    logits = outputs.pred_masks.squeeze(1)[:, 0, :, :]
    probs = torch.sigmoid(logits).cpu().squeeze().numpy()
    mask_array = (probs * 255).astype(np.uint8)
    return Image.fromarray(mask_array).resize((1024, 1024), Image.LANCZOS)


def _save_temp_png(img):
    """Write *img* to a new temporary ``.png`` file and return its path.

    The file stays on disk (``delete=False``) because Shiny reads it after the
    render function returns. The file handle itself is closed before returning.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as handle:
        img.save(handle, "PNG")
    return handle.name


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: segment the uploaded image and render it plus an overlay."""

    @reactive.Calc
    def loaded_image():
        """Process the current upload.

        Returns:
            ``None`` when nothing is uploaded. Otherwise a tuple
            ``(original_png_path, overlay_png_path)``: the first is the upload
            resized to 1024x1024, the second is that image alpha-composited
            with the predicted mask at alpha 128.
        """
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None
        # Open the upload once and close the handle as soon as it is decoded.
        with Image.open(file[0]["datapath"]) as src:
            image = src.convert("RGB")

        mask = _predict_mask(image).convert("RGBA")
        # Constant half-transparency so the photo stays visible under the mask.
        mask.putalpha(Image.new("L", mask.size, 128))

        base = image.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        combined = Image.alpha_composite(base, mask)
        return _save_temp_png(base), _save_temp_png(combined)

    @output
    @render.image
    def original_image():
        """Show the resized upload, or nothing before a file is chosen."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @output
    @render.image
    def image_display():
        """Show the mask overlay, or nothing before a file is chosen."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}
# Shiny application entry point: pairs the page layout with the server logic.
app = App(ui=app_ui, server=server)