File size: 2,973 Bytes
3819023
ab52a15
 
 
 
 
3819023
ab52a15
 
 
 
 
3819023
ab52a15
 
 
 
 
 
 
 
 
 
8b61b70
 
 
ab52a15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import pandas as pd
import tempfile
from PIL import Image
from pathlib import Path
from shiny import App, Inputs, Outputs, Session, reactive, render, ui
from shiny.types import FileInfo
import json
import torch
import numpy as np
import os
from transformers import SamModel
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Preprocessing pipeline for uploads: resize to the 1024x1024 resolution the
# SAM vision encoder expects, then convert the PIL image to a float tensor.
image_resize_transform = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ]
)

# Page layout: one file-upload control followed by two image outputs
# (the resized original and the model's mask overlay).
app_ui = ui.page_fluid(
    ui.input_file(
        "file2",
        "Choose Image",
        accept=".jpg, .jpeg, .png, .tiff, .tif",
        multiple=False,
    ),
    ui.output_image("original_image"),
    ui.output_image("image_display"),
)

def server(input: Inputs, output: Outputs, session: Session):
    """Shiny server: run SAM segmentation on an uploaded image and render
    both the original and the mask overlay.
    """
    # Cache the model per session so from_pretrained + load_state_dict
    # (both expensive) do not rerun on every single upload.
    _model_cache: dict[str, SamModel] = {}

    def _get_model(device: torch.device) -> SamModel:
        """Load the fine-tuned SAM model once and reuse it on later calls."""
        if "model" not in _model_cache:
            model = SamModel.from_pretrained("facebook/sam-vit-base")
            model.load_state_dict(torch.load('model.pth', map_location=device))
            model.eval()
            model.to(device)
            _model_cache["model"] = model
        return _model_cache["model"]

    @reactive.calc
    def loaded_image():
        """Segment the uploaded image.

        Returns:
            (original_path, overlay_path) — paths to temp PNG files holding
            the resized original and the mask composited over it — or None
            when no file has been uploaded yet.
        """
        file: list[FileInfo] | None = input.file2()
        if file is None:
            return None

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = _get_model(device)

        # Open the upload once and reuse it for both inference and display
        # (the original re-opened the same file a second time below).
        image = Image.open(file[0]["datapath"]).convert('RGB')
        image_tensor = image_resize_transform(image).to(device)

        with torch.no_grad():
            outputs = model(pixel_values=image_tensor.unsqueeze(0), multimask_output=False)
        predicted_masks = outputs.pred_masks.squeeze(1)
        # multimask_output=False -> keep the single predicted mask channel.
        predicted_masks = predicted_masks[:, 0, :, :]

        # pred_masks are raw logits; squash to [0, 1] with sigmoid before
        # scaling to 8-bit grayscale. (The original scaled logits directly,
        # which wraps around in uint8 for values outside [0, 1].)
        mask_tensor = torch.sigmoid(predicted_masks).cpu().squeeze()
        mask_array = (mask_tensor.numpy() * 255).astype(np.uint8)
        mask = Image.fromarray(mask_array)
        mask = mask.resize((1024, 1024), Image.LANCZOS).convert('RGBA')

        # Uniform 50% alpha so the underlying image stays visible through
        # the mask overlay.
        mask.putalpha(Image.new('L', mask.size, 128))

        display_image = image.resize((1024, 1024), Image.LANCZOS).convert('RGBA')
        combined = Image.alpha_composite(display_image, mask)

        # delete=False keeps the files on disk for the renderers; the `with`
        # blocks close the OS handles so descriptors are not leaked (the
        # original never closed them). Note: `quality` is a JPEG option and
        # was a no-op for PNG, so it is dropped here.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as original_file:
            display_image.save(original_file.name, "PNG")
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as combined_file:
            # BUG FIX: the original saved `mask` here, so the composited
            # `combined` image was computed but never shown.
            combined.save(combined_file.name, "PNG")

        return original_file.name, combined_file.name

    @render.image
    def original_image():
        """Render the resized original upload."""
        result = loaded_image()
        if result is None:
            return None
        img_path, _ = result
        return {"src": img_path, "width": "300px"}

    @render.image
    def image_display():
        """Render the mask overlay composited onto the original."""
        result = loaded_image()
        if result is None:
            return None
        _, img_path = result
        return {"src": img_path, "width": "300px"}

# Shiny application entry point (run with `shiny run` pointing at this file).
app = App(app_ui, server)