Spaces:
Runtime error
Runtime error
File size: 2,973 Bytes
3819023 ab52a15 3819023 ab52a15 3819023 ab52a15 8b61b70 ab52a15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import functools
import json
import os
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import SamModel
# Model-input preprocessing: every uploaded image is resized to a fixed
# 1024x1024 square and then converted from a PIL image into a tensor.
# NOTE(review): no mean/std normalization step is applied here; confirm this
# matches the preprocessing used when the weights in model.pth were trained.
image_resize_transform = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
    ]
)
# Single fluid page: one upload control for the source image, followed by the
# two image outputs that `server` fills in ("original_image" shows the upload,
# "image_display" shows the segmentation result).
app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    ui.output_image("original_image"),
    ui.output_image("image_display"),
)
@functools.lru_cache(maxsize=1)
def _load_sam_model():
    """Build the fine-tuned SAM model once and reuse it for every upload.

    Returns:
        A ``(model, device)`` tuple. The architecture is created from the
        ``facebook/sam-vit-base`` checkpoint, its weights are replaced by the
        state dict in ``model.pth``, and it is put in eval mode on CUDA when
        available, otherwise on the CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    # NOTE(review): torch.load unpickles the file; model.pth must come from a
    # trusted source (newer torch versions also offer weights_only=True).
    model.load_state_dict(torch.load("model.pth", map_location=device))
    model.eval()
    model.to(device)
    return model, device


def _predict_mask(image):
    """Segment *image* with the cached SAM model.

    Args:
        image: An RGB ``PIL.Image``; it is preprocessed with
            ``image_resize_transform`` before inference.

    Returns:
        A 2-D 8-bit grayscale ``PIL.Image`` whose pixels are the predicted
        mask probability scaled to 0..255.
    """
    model, device = _load_sam_model()
    pixel_values = image_resize_transform(image).to(device).unsqueeze(0)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, multimask_output=False)
    # Drop dim 1, keep the first mask, then squeeze away remaining size-1 dims.
    logits = outputs.pred_masks.squeeze(1)[:, 0, :, :].squeeze()
    # pred_masks are raw logits: scaling them by 255 and casting straight to
    # uint8 wraps negative / >1 values around. Map to [0, 1] first.
    probs = torch.sigmoid(logits).cpu().numpy()
    return Image.fromarray((probs * 255).astype(np.uint8))


def _save_temp_png(img):
    """Write *img* to a new ``.png`` temporary file and return its path.

    ``delete=False`` keeps the file after the handle is closed so Shiny can
    serve it. Nothing here removes these files afterwards, so one pair
    accumulates per upload.
    """
    # Write through the open handle (and close it) instead of reopening by
    # name, which fails on Windows while the temp file is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        img.save(tmp, format="PNG")
    return tmp.name


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: segment the uploaded image and display it with its mask overlay."""

    @reactive.calc
    def loaded_image():
        """Return ``(original_png_path, overlay_png_path)``, or ``None`` before any upload."""
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None

        with Image.open(file[0]["datapath"]) as src:
            image = src.convert("RGB")

        mask = _predict_mask(image)

        # Bring both layers to the same 1024x1024 size so they can be composited.
        base = image.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        overlay = mask.resize((1024, 1024), Image.LANCZOS).convert("RGBA")
        overlay.putalpha(128)  # constant half transparency for the mask layer
        combined = Image.alpha_composite(base, overlay)

        # The second output is the composited overlay, not the bare mask.
        return _save_temp_png(base), _save_temp_png(combined)

    @render.image
    def original_image():
        """Show the (resized) uploaded image."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Show the uploaded image with the predicted mask blended on top."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}
# Shiny application object: pairs the page layout with the server function.
app = App(ui=app_ui, server=server)
|