File size: 3,963 Bytes
460258c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
import vision_transformer as models
import cv2
from torch import nn
from utils import load_pretrained_weights


class PatchEmbedding:
    """
    Wraps a pretrained ViT model and produces token embeddings for an image.

    Args:
        pretrained_weights (str): Path to the pretrained weight file.
        arch (str, optional): Architecture name looked up in the
            ``vision_transformer`` module. Defaults to ``'vit_base'``.
        patch_size (int, optional): Patch size passed to the model
            constructor. Defaults to 16.

    Attributes:
        model: The ViT backbone, in eval mode with parameter gradients off.
        embed_dim (int): Embedding dimension reported by the model.
        tfms: Per-image transform: ``ToTensor`` followed by ImageNet
            mean/std normalisation (no resizing).

    Methods:
        load_pretrained_weights(pretrained_weights): Load weights into
            ``model``.
        get_representation(image): Return the L2-normalised CLS token and
            patch tokens for a single image.
    """
    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        # NOTE(review): num_classes=0 presumably removes the classification
        # head so only features are produced -- confirm in vision_transformer.
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        # Inference only: eval mode and no gradients on the parameters.
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)

        from torchvision import transforms
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

        # There is no resize step: inputs are assumed to already be at the
        # resolution the model expects.
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def load_pretrained_weights(self, pretrained_weights):
        """Load ``pretrained_weights`` into ``self.model`` via utils helper."""
        # Resolves to the module-level function imported from ``utils``,
        # not to this method (class scope is not visible inside methods).
        load_pretrained_weights(self.model, pretrained_weights)

    def get_representation(self, image):
        """
        Compute L2-normalised token embeddings for a single image.

        Args:
            image: An image accepted by ``self.tfms`` (e.g. an H x W x C
                uint8 ndarray for ``transforms.ToTensor``).

        Returns:
            tuple[ndarray, ndarray]: ``(cls_token, patch_tokens)``, where
            ``cls_token`` has shape (C,) and ``patch_tokens`` has shape
            (N - 1, C); N is the total token count including the first
            (CLS) token. Each row is unit-length (L2 norm over C).
        """
        x = self.tfms(image)[None]  # prepend a batch dimension of size 1
        # no_grad guarantees the result carries no autograd history, so the
        # .numpy() conversion below cannot fail even if gradients get
        # re-enabled on the model later.
        with torch.no_grad():
            tokens = self.model.forward_features(x)[0]  # N, C (drop batch dim)
            tokens = nn.functional.normalize(tokens, dim=-1, p=2)
        tokens = tokens.cpu().numpy()
        cls_token = tokens[0]  # C
        patch_tokens = tokens[1:]  # N - 1, C
        return cls_token, patch_tokens

    def __call__(self, x):
        """Run the backbone's ``forward_features`` on an already-batched tensor."""
        return self.model.forward_features(x)
    
# Square input resolution requested from every Gradio image component.
default_shape = 224, 224
# Single model instance shared by the Gradio callbacks; constructing it loads
# the pretrained weights at import time.
embedding = PatchEmbedding(pretrained_weights='weights/mmc.pth')


def classify(query_image, support_image):
    """Score how similar two images are using their CLS embeddings.

    ``embedding.get_representation`` returns L2-normalised tokens, so the dot
    product of the two CLS tokens is their cosine similarity; it is scaled
    by 100 for display.

    Args:
        query_image: Query image as delivered by the Gradio input, or None.
        support_image: Support image as delivered by the Gradio input, or None.

    Returns:
        str: The scaled similarity formatted with two decimals, or a prompt
        asking for both images when either one is missing.
    """
    if query_image is None or support_image is None:
        # Gradio presumably passes None for an empty image slot; the image
        # transform would otherwise fail with an unhelpful error.
        return "Please provide both a query image and a support image."

    q_cls = embedding.get_representation(query_image)[0]
    s_cls = embedding.get_representation(support_image)[0]

    sim = (q_cls * s_cls).sum() * 100
    return f"{sim:.2f}"

def segment(threshold, input):
    """Keep the image patches that resemble the user-sketched region.

    The patch tokens under the sketched strokes are averaged into a single
    prototype. Every patch whose dot-product similarity to that prototype
    exceeds ``threshold`` is kept, and all other pixels are zeroed.

    Args:
        threshold (float): Per-patch similarity cut-off.
        input (dict | None): Sketch-tool payload with an ``'image'`` array
            (H x W x C) and a ``'mask'`` array whose non-zero pixels mark the
            user's strokes; None when nothing has been uploaded.

    Returns:
        ndarray | None: uint8 image with non-matching patches blacked out,
        or None when no image was provided.
    """
    import numpy as np

    # Nothing uploaded yet: there is nothing to segment.
    if input is None or input.get('image') is None:
        return None
    image = input['image']
    mask = input['mask']

    patch_tokens = embedding.get_representation(image)[1]  # N - 1, C
    # Side of the square patch grid (14 for a 224x224 image with 16px
    # patches); derived rather than hard-coded so other sizes also work.
    grid = int(round(len(patch_tokens) ** 0.5))

    # Accept a single-channel mask as well as a multi-channel one.
    strokes = mask if mask.ndim == 2 else mask[:, :, 0]
    # INTER_AREA averages all pixels covered by each patch, so any stroke
    # pixel inside a patch selects it. The default bilinear resize only
    # samples near each patch centre and can miss thin strokes entirely.
    small = cv2.resize(strokes, (grid, grid), interpolation=cv2.INTER_AREA)
    select = (small > 0).flatten()
    if not select.any():
        # No patch marked: the mean below would be NaN (with a warning) and
        # every comparison False. Return the same all-black image directly.
        return np.zeros_like(image, dtype='uint8')

    q_pat = patch_tokens[select].mean(0)  # C
    sim = patch_tokens @ q_pat  # N - 1

    keep = (sim.reshape(grid, grid) > threshold).astype('float')
    # cv2 expects dsize as (width, height).
    keep = cv2.resize(keep, (image.shape[1], image.shape[0]))
    ans = image * keep[:, :, None]
    return ans.astype('uint8')

# Tab 1: compare the CLS embeddings of a query image and a support image.
_classification_inputs = [
    gr.inputs.Image(label="Query Image", shape=default_shape),
    gr.inputs.Image(label="Support Image", shape=default_shape),
]
classification_tab = gr.Interface(
    fn=classify,
    inputs=_classification_inputs,
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification",
)

# Tab 2: sketch over part of an image and keep the patches similar to it.
_segmentation_inputs = [
    gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8),
    gr.inputs.Image(label="Input Image", tool="sketch", shape=default_shape),
]
segmentation_tab = gr.Interface(
    fn=segment,
    inputs=_segmentation_inputs,
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation",
)

# Tab titles mapped to their interfaces; dict order fixes the tab order.
_tabs = {
    "Classification": classification_tab,
    "Segmentation": segmentation_tab,
}
interface = gr.TabbedInterface(list(_tabs.values()), list(_tabs.keys()))

interface.launch()