import gc

import torch
import torch.nn as nn
from torchvision import transforms

from ram.models import ram


class TaggingModule(nn.Module):
    """Wraps the Recognize Anything Model (RAM) for open-set image tagging."""

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

        # RAM expects fixed-size 384x384 inputs normalized with ImageNet statistics.
        image_size = 384
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # Load the RAM model (Swin-L backbone) from a local checkpoint.
        self.ram = ram(
            pretrained='checkpoints/ram_swin_large_14m.pth',
            image_size=image_size,
            vit='swin_l'
        ).eval().to(device)
        print('==> Tagging Module Loaded.')

        # Free temporary objects created during checkpoint loading.
        gc.collect()

    @torch.no_grad()
    def forward(self, original_image):
        """Return a list of English tags for a PIL image."""
        print('==> Tagging...')
        img = self.transform(original_image).unsqueeze(0).to(self.device)
        # generate_tag returns English and Chinese tag strings; only English is used.
        tags, _tags_chinese = self.ram.generate_tag(img)
        print('==> Tagging results: {}'.format(tags[0]))
        # Tags come back as a single ' | '-separated string.
        return tags[0].split(' | ')
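
# Usage sketch: a minimal example of driving TaggingModule. It assumes Pillow
# is installed, the RAM checkpoint referenced in __init__ has been downloaded,
# and an image exists at 'demo.jpg' (hypothetical path).
if __name__ == '__main__':
    from PIL import Image

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tagger = TaggingModule(device=device)
    image = Image.open('demo.jpg').convert('RGB')  # hypothetical input image
    tags = tagger(image)
    print(tags)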