# Scraped file-page header (not code): Ansal · change · 3819023 · raw · history blame · 2.97 kB
import functools
import json
import os
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
from transformers import SamModel
# Preprocessing for uploaded images before inference: resize to a fixed
# 1024x1024 square, then convert the PIL image to a tensor.
# NOTE(review): no mean/std normalization is applied here; presumably the
# fine-tuned checkpoint was trained with this exact pipeline — confirm.
image_resize_transform = transforms.Compose(
    [
        transforms.Resize(size=(1024, 1024)),
        transforms.ToTensor(),
    ]
)
# Page layout: one single-file upload control ("file2"), followed by the two
# image outputs the server fills in ("original_image", "image_display").
app_ui = ui.page_fluid(
    ui.input_file(
        id="file2",
        label="Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    ui.output_image(id="original_image"),
    ui.output_image(id="image_display"),
)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the fine-tuned SAM model once and return it with its device.

    Cached at module level so the ``facebook/sam-vit-base`` weights and the
    local ``model.pth`` checkpoint are loaded on the first upload only, not
    on every upload in every session.

    Returns:
        tuple: ``(model, device)``. ``model`` is a ``SamModel`` in eval mode,
        already moved to ``device`` (CUDA when available, otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load('model.pth', map_location=device))
    model.eval()
    model.to(device)
    return model, device


def _logits_to_mask_image(mask_logits):
    """Convert a 2-D tensor of SAM mask logits into an 8-bit grayscale image.

    SAM's ``pred_masks`` are raw logits: unbounded and frequently negative.
    Scaling them straight by 255 and casting to ``uint8`` wraps around.
    A sigmoid first maps them to probabilities in [0, 1], which then scale
    cleanly to 0-255.
    """
    probabilities = torch.sigmoid(mask_logits).detach().cpu().numpy()
    return Image.fromarray((probabilities * 255).round().astype(np.uint8))


def _save_temp_png(image):
    """Write ``image`` to a new temporary ``.png`` file and return its path.

    The handle is closed before returning. The file itself is kept
    (``delete=False``) so Shiny can read it when rendering.
    """
    # NOTE(review): these files are never removed; consider cleanup on
    # session end if uploads are frequent.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
        image.save(handle, format="PNG")
    return handle.name


def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: segment the uploaded image and show it with a mask overlay.

    The ``original_image`` output shows the upload resized to 1024x1024.
    The ``image_display`` output shows the same image with the predicted
    mask alpha-composited over it at 50% opacity.
    """

    @reactive.calc
    def loaded_image():
        """Run segmentation on the current upload.

        Returns:
            ``None`` if no file has been uploaded. Otherwise a tuple
            ``(original_path, overlay_path)`` of temporary PNG files.
        """
        file: list[FileInfo] | None = input.file2()
        if not file:
            return None

        model, device = _load_model()

        # Decode the upload once and release the file handle right away.
        with Image.open(file[0]["datapath"]) as uploaded:
            rgb_image = uploaded.convert('RGB')

        image_tensor = image_resize_transform(rgb_image).to(device)
        with torch.no_grad():
            outputs = model(
                pixel_values=image_tensor.unsqueeze(0), multimask_output=False
            )

        # Select the first mask; presumably pred_masks is
        # (batch, point_batch, num_masks, H, W) — confirm. The final
        # squeeze drops the remaining singleton batch dimension.
        predicted_masks = outputs.pred_masks.squeeze(1)[:, 0, :, :]
        mask = _logits_to_mask_image(predicted_masks.squeeze())
        mask = mask.resize((1024, 1024), Image.LANCZOS).convert('RGBA')
        # Uniform alpha of 128 makes the overlay roughly half transparent.
        mask.putalpha(Image.new('L', mask.size, 128))

        image = rgb_image.resize((1024, 1024), Image.LANCZOS).convert('RGBA')
        combined = Image.alpha_composite(image, mask)

        # Previously only the bare mask was saved and `combined` was discarded.
        return _save_temp_png(image), _save_temp_png(combined)

    @render.image
    def original_image():
        """Render the resized upload 300 px wide; render nothing before an upload."""
        result = loaded_image()
        if result is None:
            return None
        original_path, _ = result
        return {"src": original_path, "width": "300px"}

    @render.image
    def image_display():
        """Render the mask overlay 300 px wide; render nothing before an upload."""
        result = loaded_image()
        if result is None:
            return None
        _, overlay_path = result
        return {"src": overlay_path, "width": "300px"}
# Wire the page layout and the server logic together into the Shiny app object.
app = App(ui=app_ui, server=server)