""" * Tag2Text * Written by Xinyu Huang """ import argparse import random import numpy as np import torch import torchvision.transforms as transforms from models.tag2text import tag2text_caption from PIL import Image parser = argparse.ArgumentParser( description="Tag2Text inferece for tagging and captioning" ) parser.add_argument( "--image", metavar="DIR", help="path to dataset", default="images/1641173_2291260800.jpg", ) parser.add_argument( "--pretrained", metavar="DIR", help="path to pretrained model", default="pretrained/tag2text_swin_14m.pth", ) parser.add_argument( "--image-size", default=384, type=int, metavar="N", help="input image size (default: 448)", ) parser.add_argument( "--thre", default=0.68, type=float, metavar="N", help="threshold value" ) parser.add_argument( "--specified-tags", default="None", help="User input specified tags" ) def inference(image, model, input_tag="None"): with torch.no_grad(): caption, tag_predict = model.generate( image, tag_input=None, max_length=50, return_tag_predict=True ) if input_tag == "" or input_tag == "none" or input_tag == "None": return tag_predict[0], None, caption[0] # If user input specified tags: else: input_tag_list = [] input_tag_list.append(input_tag.replace(",", " | ")) with torch.no_grad(): caption, input_tag = model.generate( image, tag_input=input_tag_list, max_length=50, return_tag_predict=True ) return tag_predict[0], input_tag[0], caption[0] if __name__ == "__main__": args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) transform = transforms.Compose( [ transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), normalize, ] ) # delete some tags that may disturb captioning # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359] #######load model model = tag2text_caption( pretrained=args.pretrained, image_size=args.image_size, vit="swin_b", delete_tag_index=delete_tag_index, ) model.threshold = args.thre # threshold for tagging model.eval() model = model.to(device) raw_image = Image.open(args.image).resize((args.image_size, args.image_size)) image = transform(raw_image).unsqueeze(0).to(device) res = inference(image, model, args.specified_tags) print("Model Identified Tags: ", res[0]) print("User Specified Tags: ", res[1]) print("Image Caption: ", res[2])