| from argparse import ArgumentParser |
| from pathlib import Path |
| from typing import Dict, List, Optional, TextIO, Tuple |
|
|
| import torch |
| from PIL import Image, UnidentifiedImageError |
| from torch import Tensor |
| from torch.nn import Module, Parameter |
| from torch.nn.functional import relu, sigmoid |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| import os |
| import json |
|
|
| from ram import get_transform |
| from ram.models import ram_plus, ram, tag2text |
| from ram.utils import build_openset_llm_label_embedding, build_openset_label_embedding, get_mAP, get_PR |
|
|
# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
class _Dataset(Dataset):
    """Image dataset over a list of path stems (file extension is appended here).

    Each item is the transformed image; unreadable or missing files are
    replaced by a small black placeholder so evaluation can continue.
    """

    def __init__(self, imglist, input_size):
        self.imglist = imglist
        self.transform = get_transform(input_size)

    def __len__(self):
        return len(self.imglist)

    def __getitem__(self, index):
        path = self.imglist[index] + ".jpg"
        try:
            img = Image.open(path)
        except (OSError, FileNotFoundError, UnidentifiedImageError):
            # Keep batch shape intact with a tiny placeholder image.
            print("Error loading image:", self.imglist[index])
            img = Image.new('RGB', (10, 10), 0)
        return self.transform(img)
|
|
|
|
def parse_args():
    """Parse and post-process command-line options for the evaluation run."""
    parser = ArgumentParser()

    # Model selection.
    parser.add_argument("--model-type",
                        type=str,
                        choices=("ram_plus", "ram", "tag2text"),
                        required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--backbone",
                        type=str,
                        choices=("swin_l", "swin_b"),
                        default=None,
                        help="If `None`, will judge from `--model-type`")
    parser.add_argument("--open-set",
                        action="store_true",
                        help=(
                            "Treat all categories in the taglist file as "
                            "unseen and perform open-set classification. Only "
                            "works with RAM."
                        ))

    # Dataset selection.
    parser.add_argument("--dataset",
                        type=str,
                        choices=("openimages_common_214", "openimages_rare_200"),
                        required=True)
    parser.add_argument("--input-size", type=int, default=384)

    # Threshold options are mutually exclusive.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--threshold",
                       type=float,
                       default=None,
                       help=(
                           "Use custom threshold for all classes. Mutually "
                           "exclusive with `--threshold-file`. If both "
                           "`--threshold` and `--threshold-file` is `None`, "
                           "will use a default threshold setting."
                       ))
    group.add_argument("--threshold-file",
                       type=str,
                       default=None,
                       help=(
                           "Use custom class-wise thresholds by providing a "
                           "text file. Each line is a float-type threshold, "
                           "following the order of the tags in taglist file. "
                           "See `ram/data/ram_tag_list_threshold.txt` as an "
                           "example. Mutually exclusive with `--threshold`. "
                           "If both `--threshold` and `--threshold-file` is "
                           "`None`, will use default threshold setting."
                       ))

    # Run configuration.
    parser.add_argument("--output-dir", type=str, default="./outputs")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-workers", type=int, default=4)

    args = parser.parse_args()

    args.model_type = args.model_type.lower()

    # Open-set evaluation is not supported for tag2text.
    assert not (args.model_type == "tag2text" and args.open_set)

    # Default backbone follows the model family.
    if args.backbone is None:
        args.backbone = "swin_l" if args.model_type in ("ram_plus", "ram") else "swin_b"

    return args
|
|
|
|
def load_dataset(
    dataset: str,
    model_type: str,
    input_size: int,
    batch_size: int,
    num_workers: int
) -> Tuple[DataLoader, Dict]:
    """Build the evaluation DataLoader and a dict of dataset metadata.

    Returns ``(loader, info)`` where ``info`` carries the taglist, the image
    path list, the annotation file path, the image root, and (optionally)
    LLM tag descriptions used for open-set RAM++.
    """
    dataset_root = str(Path(__file__).resolve().parent / "datasets" / dataset)
    img_root = dataset_root + "/imgs"

    # RAM/RAM++ share one taglist format; tag2text uses id-based files.
    if model_type in ("ram_plus", "ram"):
        tag_file = dataset_root + f"/{dataset}_ram_taglist.txt"
        annot_file = dataset_root + f"/{dataset}_ram_annots.txt"
    else:
        tag_file = dataset_root + f"/{dataset}_tag2text_tagidlist.txt"
        annot_file = dataset_root + f"/{dataset}_{model_type}_idannots.txt"

    with open(tag_file, "r", encoding="utf-8") as f:
        taglist = [line.strip() for line in f]

    # First comma-separated field of each annotation line is the image stem.
    with open(annot_file, "r", encoding="utf-8") as f:
        imglist = [img_root + "/" + line.strip().split(",")[0] for line in f]

    loader = DataLoader(
        dataset=_Dataset(imglist, input_size),
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # Optional LLM tag descriptions (only present for open-set datasets).
    open_tag_des = dataset_root + f"/{dataset}_llm_tag_descriptions.json"
    tag_des = None
    if os.path.exists(open_tag_des):
        with open(open_tag_des, 'rb') as fo:
            tag_des = json.load(fo)

    info = {
        "taglist": taglist,
        "imglist": imglist,
        "annot_file": annot_file,
        "img_root": img_root,
        "tag_des": tag_des,
    }
    return loader, info
|
|
|
|
def get_class_idxs(
    model_type: str,
    open_set: bool,
    taglist: List[str]
) -> Optional[List[int]]:
    """Get indices of required categories in the label system."""
    if model_type not in ("ram_plus", "ram"):
        # tag2text taglists already contain integer tag ids.
        return [int(tag) for tag in taglist]
    if open_set:
        # Open-set: label embeddings are rebuilt, so no index mapping exists.
        return None
    # Closed-set RAM/RAM++: map each tag to its position in the model taglist.
    model_taglist_file = "ram/data/ram_tag_list.txt"
    with open(model_taglist_file, "r", encoding="utf-8") as f:
        model_taglist = [line.strip() for line in f]
    return [model_taglist.index(tag) for tag in taglist]
|
|
|
|
def load_thresholds(
    threshold: Optional[float],
    threshold_file: Optional[str],
    model_type: str,
    open_set: bool,
    class_idxs: List[int],
    num_classes: int,
) -> List[float]:
    """Decide what threshold(s) to use.

    Priority: explicit ``threshold_file`` > explicit ``threshold`` > model
    defaults. Unset-ness is tested with ``is None`` (not truthiness) so that
    an explicit ``--threshold 0.0`` is honored instead of being silently
    replaced by the default thresholds.
    """
    if threshold_file is not None:
        # One float per line, in the same order as the taglist.
        with open(threshold_file, "r", encoding="utf-8") as f:
            thresholds = [float(line.strip()) for line in f]
        assert len(thresholds) == num_classes
        return thresholds
    if threshold is not None:
        # A single custom threshold applied to every class.
        return [threshold] * num_classes
    # Defaults: closed-set RAM/RAM++ uses the released per-class thresholds,
    # open-set falls back to 0.5, and tag2text uses 0.68.
    if model_type in ("ram_plus", "ram"):
        if open_set:
            return [0.5] * num_classes
        ram_threshold_file = "ram/data/ram_tag_list_threshold.txt"
        with open(ram_threshold_file, "r", encoding="utf-8") as f:
            idx2thre = {
                idx: float(line.strip()) for idx, line in enumerate(f)
            }
        return [idx2thre[idx] for idx in class_idxs]
    return [0.68] * num_classes
|
|
|
|
def gen_pred_file(
    imglist: List[str],
    tags: List[List[str]],
    img_root: str,
    pred_file: str
) -> None:
    """Generate text file of tag prediction results."""
    lines = []
    for image, tag in zip(imglist, tags):
        # One line per image: relative path, then the predicted tags (if any).
        entry = str(Path(image).relative_to(img_root))
        if tag:
            entry = entry + "," + ",".join(tag)
        lines.append(entry + "\n")
    with open(pred_file, "w", encoding="utf-8") as f:
        f.writelines(lines)
|
|
def load_ram_plus(
    backbone: str,
    checkpoint: str,
    input_size: int,
    taglist: List[str],
    tag_des: List[str],
    open_set: bool,
    class_idxs: List[int],
) -> Module:
    """Load a RAM++ checkpoint and restrict its label space to the eval tags."""
    model = ram_plus(pretrained=checkpoint, image_size=input_size, vit=backbone)

    if open_set:
        # Unseen categories: synthesize label embeddings from LLM descriptions.
        print("Building tag embeddings ...")
        label_embed, _ = build_openset_llm_label_embedding(tag_des)
        model.label_embed = Parameter(label_embed.float())
        model.num_class = len(tag_des)
    else:
        # Closed-set: keep only the rows of the requested classes. The stored
        # matrix packs 51 description embeddings of width 512 per class.
        per_class = 51
        stacked = model.label_embed.data.reshape(model.num_class, per_class, 512)
        kept = stacked[class_idxs, :, :]
        model.label_embed = Parameter(kept.reshape(len(class_idxs) * per_class, 512))
        model.num_class = len(class_idxs)
    return model.to(device).eval()
|
|
|
|
def load_ram(
    backbone: str,
    checkpoint: str,
    input_size: int,
    taglist: List[str],
    open_set: bool,
    class_idxs: List[int],
) -> Module:
    """Load a RAM checkpoint and restrict its label space to the eval tags."""
    model = ram(pretrained=checkpoint, image_size=input_size, vit=backbone)

    if open_set:
        # Unseen categories: build label embeddings from the taglist itself.
        print("Building tag embeddings ...")
        label_embed, _ = build_openset_label_embedding(taglist)
        model.label_embed = Parameter(label_embed.float())
    else:
        # Closed-set: slice out the embedding rows of the requested classes.
        model.label_embed = Parameter(model.label_embed[class_idxs, :])
    return model.to(device).eval()
|
|
|
|
def load_tag2text(
    backbone: str,
    checkpoint: str,
    input_size: int
) -> Module:
    """Load a Tag2Text checkpoint and move it to the eval device."""
    model = tag2text(
        pretrained=checkpoint,
        image_size=input_size,
        vit=backbone
    )
    model = model.to(device)
    return model.eval()
|
|
@torch.no_grad()
def forward_ram_plus(model: Module, imgs: Tensor) -> Tensor:
    """Run a RAM++ forward pass; returns sigmoid scores of shape (bs, num_class)."""
    image_embeds = model.image_proj(model.visual_encoder(imgs.to(device)))
    image_atts = torch.ones(
        image_embeds.size()[:-1], dtype=torch.long).to(device)

    # CLS token drives description reweighting; remaining tokens are spatial.
    image_cls_embeds = image_embeds[:, 0, :]
    image_spatial_embeds = image_embeds[:, 1:, :]

    bs = image_spatial_embeds.shape[0]

    # Number of LLM descriptions stored per class (label_embed is flattened).
    des_per_class = int(model.label_embed.shape[0] / model.num_class)

    image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True)
    reweight_scale = model.reweight_scale.exp()
    logits_per_image = (reweight_scale * image_cls_embeds @ model.label_embed.t())
    logits_per_image = logits_per_image.view(bs, -1, des_per_class)

    # Softmax over each class's descriptions, then collapse them into one
    # per-class embedding. The original computed weight_normalized and the
    # output buffer twice (duplicated lines) — removed. `.to(device)` replaces
    # a hard-coded `.cuda()` so CPU-only runs do not crash.
    weight_normalized = F.softmax(logits_per_image, dim=2)
    label_embed_reweight = torch.empty(bs, model.num_class, 512).to(device)
    # The view is loop-invariant; hoist it out of the per-sample loop.
    reshaped_value = model.label_embed.view(-1, des_per_class, 512)
    for i in range(bs):
        product = weight_normalized[i].unsqueeze(-1) * reshaped_value
        label_embed_reweight[i] = product.sum(dim=1)

    label_embed = relu(model.wordvec_proj(label_embed_reweight))

    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_embed,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=False,
        mode='tagging',
    )
    return sigmoid(model.fc(tagging_embed).squeeze(-1))
|
|
@torch.no_grad()
def forward_ram(model: Module, imgs: Tensor) -> Tensor:
    """Run a RAM forward pass; returns sigmoid scores of shape (bs, num_class)."""
    batch = imgs.shape[0]
    image_embeds = model.image_proj(model.visual_encoder(imgs.to(device)))
    attn_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
    # Project the label embeddings and replicate them for every image in the batch.
    label_queries = relu(model.wordvec_proj(model.label_embed))
    label_queries = label_queries.unsqueeze(0).repeat(batch, 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_queries,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=attn_mask,
        return_dict=False,
        mode='tagging',
    )
    return sigmoid(model.fc(tagging_embed).squeeze(-1))
|
|
|
|
@torch.no_grad()
def forward_tag2text(
    model: Module,
    class_idxs: List[int],
    imgs: Tensor
) -> Tensor:
    """Run a Tag2Text forward pass; returns sigmoid scores for the requested classes."""
    batch = imgs.shape[0]
    image_embeds = model.visual_encoder(imgs.to(device))
    attn_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
    # Replicate the label embedding table for every image in the batch.
    label_queries = model.label_embed.weight.unsqueeze(0).repeat(batch, 1, 1)
    tagging_embed, _ = model.tagging_head(
        encoder_embeds=label_queries,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=attn_mask,
        return_dict=False,
        mode='tagging',
    )
    # Keep only the columns of the evaluated categories.
    return sigmoid(model.fc(tagging_embed))[:, class_idxs]
|
|
|
|
def print_write(f: TextIO, s: str):
    """Echo *s* to stdout and append it, newline-terminated, to *f*."""
    print(s)
    f.write(f"{s}\n")
|
|
|
|
if __name__ == "__main__":
    args = parse_args()

    # Prepare the output directory and the files produced by this run.
    output_dir = args.output_dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pred_file, pr_file, ap_file, summary_file, logit_file = [
        output_dir + "/" + name for name in
        ("pred.txt", "pr.txt", "ap.txt", "summary.txt", "logits.pth")
    ]
    # Record the run configuration at the top of the summary (echoed to stdout).
    with open(summary_file, "w", encoding="utf-8") as f:
        print_write(f, "****************")
        for key in (
            "model_type", "backbone", "checkpoint", "open_set",
            "dataset", "input_size",
            "threshold", "threshold_file",
            "output_dir", "batch_size", "num_workers"
        ):
            print_write(f, f"{key}: {getattr(args, key)}")
        print_write(f, "****************")

    # Dataset loader plus metadata (taglist, image list, annotations, ...).
    loader, info = load_dataset(
        dataset=args.dataset,
        model_type=args.model_type,
        input_size=args.input_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers
    )
    taglist, imglist, annot_file, img_root, tag_des = \
        info["taglist"], info["imglist"], info["annot_file"], \
        info["img_root"], info["tag_des"]

    # Map the evaluation tags into the model's label space.
    class_idxs = get_class_idxs(
        model_type=args.model_type,
        open_set=args.open_set,
        taglist=taglist
    )

    # Per-class decision thresholds.
    thresholds = load_thresholds(
        threshold=args.threshold,
        threshold_file=args.threshold_file,
        model_type=args.model_type,
        open_set=args.open_set,
        class_idxs=class_idxs,
        num_classes=len(taglist)
    )

    if Path(logit_file).is_file():
        # Reuse cached logits from a previous run and skip inference entirely.
        logits = torch.load(logit_file)
    else:
        # Load the requested model variant.
        if args.model_type == "ram_plus":
            model = load_ram_plus(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size,
                taglist=taglist,
                tag_des=tag_des,
                open_set=args.open_set,
                class_idxs=class_idxs
            )
        elif args.model_type == "ram":
            model = load_ram(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size,
                taglist=taglist,
                open_set=args.open_set,
                class_idxs=class_idxs
            )
        elif args.model_type == "tag2text":
            model = load_tag2text(
                backbone=args.backbone,
                checkpoint=args.checkpoint,
                input_size=args.input_size
            )

        # Inference over the whole dataset, filling one row per image.
        logits = torch.empty(len(imglist), len(taglist))
        pos = 0
        for imgs in tqdm(loader, desc="inference"):
            if args.model_type == "ram_plus":
                out = forward_ram_plus(model, imgs)
            elif args.model_type == "ram":
                out = forward_ram(model, imgs)
            else:
                out = forward_tag2text(model, class_idxs, imgs)
            bs = imgs.shape[0]
            logits[pos:pos + bs, :] = out.cpu()
            pos += bs

        # Cache logits so later runs with the same output dir skip inference.
        torch.save(logits, logit_file)

    # Threshold the logits into per-image tag predictions.
    pred_tags = []
    for scores in logits.tolist():
        pred_tags.append([
            taglist[i] for i, s in enumerate(scores) if s >= thresholds[i]
        ])

    gen_pred_file(imglist, pred_tags, img_root, pred_file)

    # Metrics: mAP from raw logits, precision/recall from thresholded tags.
    mAP, APs = get_mAP(logits.numpy(), annot_file, taglist)
    CP, CR, Ps, Rs = get_PR(pred_file, annot_file, taglist)

    with open(ap_file, "w", encoding="utf-8") as f:
        f.write("Tag,AP\n")
        for tag, AP in zip(taglist, APs):
            f.write(f"{tag},{AP*100.0:.2f}\n")

    with open(pr_file, "w", encoding="utf-8") as f:
        f.write("Tag,Precision,Recall\n")
        for tag, P, R in zip(taglist, Ps, Rs):
            f.write(f"{tag},{P*100.0:.2f},{R*100.0:.2f}\n")

    # BUGFIX: open in append mode ("a", was "w") so the configuration header
    # written at the top of this script is not truncated away by the metrics.
    with open(summary_file, "a", encoding="utf-8") as f:
        print_write(f, f"mAP: {mAP*100.0}")
        print_write(f, f"CP: {CP*100.0}")
        print_write(f, f"CR: {CR*100.0}")
|