File size: 1,114 Bytes
e4bd7f9
 
 
 
 
 
 
 
 
 
192e5fb
e4bd7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192e5fb
e4bd7f9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
from torchvision.transforms import transforms

from ram.models import ram


class TaggingModule(nn.Module):
    """Produce a list of text tags for a single image using a RAM model.

    The module owns a preprocessing pipeline (resize to a square, convert to
    a tensor, normalize per channel) and a ``ram`` model placed in eval mode
    on ``device``.
    """

    def __init__(self, device='cpu',
                 pretrained='checkpoints/ram_swin_large_14m.pth'):
        """Build the preprocessing pipeline and load the RAM model.

        Args:
            device: Device the model is moved to and inputs are sent to.
            pretrained: Path of the checkpoint handed to ``ram``. Defaults to
                the path that was previously hard-coded, so existing callers
                are unaffected.
        """
        super().__init__()
        import gc
        self.device = device
        # The same edge length is used for the resize and passed to the model.
        image_size = 384
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Per-channel mean/std (the values commonly used for ImageNet).
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # load RAM Model
        self.ram = ram(
            pretrained=pretrained,
            image_size=image_size,
            vit='swin_l',
        ).eval().to(device)
        print('==> Tagging Module Loaded.')
        # Presumably frees temporaries left over from checkpoint loading.
        gc.collect()

    @torch.no_grad()
    def forward(self, original_image):
        """Tag one image and return its tags as a list of strings.

        Args:
            original_image: Image accepted by ``self.transform``
                (presumably a 3-channel PIL image -- confirm with callers).

        Returns:
            The tags reported for the image, split on ``' | '``. Empty
            entries are dropped, so an empty tag string yields ``[]``
            instead of ``['']``.
        """
        print('==> Tagging...')
        # Add a leading batch dimension: the model is called with one image.
        img = self.transform(original_image).unsqueeze(0).to(self.device)
        # The second return value is not used here.
        tags, _ = self.ram.generate_tag(img)
        tag_string = tags[0]
        print(f'==> Tagging results: {tag_string}')
        return [tag for tag in tag_string.split(' | ') if tag]