diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e7147d9df2efd3675b8c80f5675ed35089916467 100644 --- a/.gitattributes +++ b/.gitattributes @@ -17,10 +17,6 @@ *.ot filter=lfs diff=lfs merge=lfs -text *.parquet filter=lfs diff=lfs merge=lfs -text *.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text *.rar filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text saved_model/**/* filter=lfs diff=lfs merge=lfs -text @@ -33,3 +29,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0c77ed152ac2bc2d99a8b20f157520d3603ccf74 --- /dev/null +++ b/.gitignore @@ -0,0 +1,125 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/en/_build/ +docs/zh_cn/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +data/ +data +.vscode +.idea/ +.DS_Store + +# custom +*.pkl +*.pkl.json +*.log.json +docs/modelzoo_statistics.md +mmdet/.mim +work_dirs/ + +# Pytorch +*.py~ +*.sh~ + +# remove tmp folder +tmp/ diff --git a/app/configs/m2_convl.py b/app/configs/m2_convl.py new file mode 100644 index 0000000000000000000000000000000000000000..1a05a8035172360676ee2a4c80bc8d414448222d --- /dev/null +++ b/app/configs/m2_convl.py @@ -0,0 +1,152 @@ +from torch.nn import GroupNorm, ReLU + +from mmdet.models import MSDeformAttnPixelDecoder, CrossEntropyLoss, DiceLoss, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, ClassificationCost, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +from seg.models.detectors import Mask2formerVideo +from seg.models.fusion_head import OMGFusionHead +from seg.models.heads import Mask2FormerVideoHead +from seg.models.backbones import OpenCLIPBackbone + 
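For context, the file added below is an mmengine-style Python config: the model dict it defines is meant to be built through the registry rather than imported as ordinary code. A rough sketch of that workflow, assuming mmengine's pure-Python config style and the mmdet 3.x registries (the demo app may wire this up differently, and data_preprocessor still has to be filled in):

# Hedged sketch, not part of this PR; assumes the repo's seg package is importable.
from mmengine.config import Config
from mmdet.registry import MODELS

cfg = Config.fromfile('app/configs/m2_convl.py')
model = MODELS.build(cfg.model)   # -> Mask2formerVideo with the OpenCLIP ConvNeXt-L backbone
model.init_weights()              # panoptic_head.init_cfg points at ./models/omg_seg_convl.pth
model.eval()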
+num_things_classes = 80 +num_stuff_classes = 53 + +ov_model_name = 'convnext_large_d_320' +ov_datasets_name = 'CocoPanopticOVDataset' +model = dict( + type=Mask2formerVideo, + data_preprocessor=None, # to fill + backbone=dict( + type=OpenCLIPBackbone, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + init_cfg=dict( + type='Pretrained', + checkpoint='./models/omg_seg_convl.pth', + prefix='panoptic_head.' + ), + type=Mask2FormerVideoHead, + sphere_cls=True, + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=None # [1.0] * num_classes + [0.1] + ), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean' + ) + ), + panoptic_fusion_head=dict( + type=OMGFusionHead, + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + loss_panoptic=None, + init_cfg=None + ), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + dict(type=ClassificationCost, weight=2.0), + dict( + type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True, + object_mask_thr=0., + ), + init_cfg=None +) diff --git a/assets/000000000139.jpg b/assets/000000000139.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcc0ee640f200f4557b4b77fd2fd010cf1d1bc2c --- /dev/null +++ b/assets/000000000139.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ffe0f0cec3b2e27aab1967229cdf0a0d7751dcdd5800322f0b8ac0dffb3b8a8d +size 161811 diff --git a/assets/000000000285.jpg b/assets/000000000285.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d6fac3296dfd30328248b229bcea70f1d6691119 --- /dev/null +++ b/assets/000000000285.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3a2974ce3686332609124c70e3e6a2e3aca43fccf1cd1bd7c5c03820977f57d +size 335861 diff --git a/assets/000000000632.jpg b/assets/000000000632.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a13832dd7305067b5d33cfa7ce6d695c3972347 --- /dev/null +++ b/assets/000000000632.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4cd7f45ac1ce27eaafb254b23af7c0b18a064be08870ceaaf03b2147f2ce550 +size 155667 diff --git a/assets/000000000724.jpg b/assets/000000000724.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5c43ddad3428ab43630dea7797807454e64f6d80 --- /dev/null +++ b/assets/000000000724.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c0e559c75d3969c8e3e297b61f61063f78045c9d4802b526ba616361f3823fd +size 130107 diff --git a/ext/cityscapes_scripts/createPanopticImgs.py b/ext/cityscapes_scripts/createPanopticImgs.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4eaa3478f557d834f2948d73e37bc003b564cb --- /dev/null +++ b/ext/cityscapes_scripts/createPanopticImgs.py @@ -0,0 +1,194 @@ +#!/usr/bin/python +# +# Converts the *instanceIds.png annotations of the Cityscapes dataset +# to COCO-style panoptic segmentation format (http://cocodataset.org/#format-data). +# The convertion is working for 'fine' set of the annotations. +# +# By default with this tool uses IDs specified in labels.py. You can use flag +# --use-train-id to get train ids for categories. 'ignoreInEval' categories are +# removed during the conversion. +# +# In panoptic segmentation format image_id is used to match predictions and ground truth. +# For cityscapes image_id has form _123456_123456 and corresponds to the prefix +# of cityscapes image files. 
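As a usage note (not part of the script), the converter defined below can be driven through its CLI flags or called directly; a small illustrative sketch, with paths assumed from the script's defaults (pass the gtFine folder itself, as the --dataset-folder help describes):

# Illustrative only; paths are assumptions based on the defaults in convert2panoptic().
from ext.cityscapes_scripts.createPanopticImgs import convert2panoptic

convert2panoptic(
    cityscapesPath='data/cityscapes/gtFine',
    outputFolder='data/cityscapes/annotations',
    useTrainId=False,
    setNames=['val'],
)
# The output JSON uses the file-name prefix (city_sequence_frame) as image_id,
# e.g. "frankfurt_000000_000294" for frankfurt_000000_000294_gtFine_instanceIds.png.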
+# + +# python imports +from __future__ import print_function, absolute_import, division, unicode_literals +import os +import glob +import sys +import argparse +import json +import numpy as np + +# Image processing +from PIL import Image + +# cityscapes imports +from ext.cityscapes_scripts.helpers.csHelpers import printError +from ext.cityscapes_scripts.helpers.labels import id2label, labels + + +import mmengine + + +# The main method +def convert2panoptic(cityscapesPath=None, outputFolder=None, useTrainId=False, setNames=["val", "train", "test"]): + # Where to look for Cityscapes + if cityscapesPath is None: + if 'CITYSCAPES_DATASET' in os.environ: + cityscapesPath = os.environ['CITYSCAPES_DATASET'] + else: + cityscapesPath = 'data/cityscapes' + cityscapesPath = os.path.join(cityscapesPath, "gtFine") + + if outputFolder is None: + outputFolder = cityscapesPath.replace('gtFine', "annotations") + + mmengine.mkdir_or_exist(outputFolder) + + categories = [] + for label in labels: + if label.ignoreInEval: + continue + categories.append({'id': int(label.trainId) if useTrainId else int(label.id), + 'name': label.name, + 'color': label.color, + 'supercategory': label.category, + 'isthing': 1 if label.hasInstances else 0}) + + categories = sorted(categories, key=lambda x:x['id']) + + for setName in setNames: + # how to search for all ground truth + searchFine = os.path.join(cityscapesPath, setName, "*", "*_instanceIds.png") + # search files + filesFine = glob.glob(searchFine) + filesFine.sort() + + files = filesFine + # quit if we did not find anything + if not files: + printError( + "Did not find any files for {} set using matching pattern {}. Please consult the README.".format(setName, searchFine) + ) + # a bit verbose + print("Converting {} annotation files for {} set.".format(len(files), setName)) + + trainIfSuffix = "_trainId" if useTrainId else "" + outputBaseFile = "cityscapes_panoptic_{}{}".format(setName, trainIfSuffix) + outFile = os.path.join(outputFolder, "{}.json".format(outputBaseFile)) + print("Json file with the annotations in panoptic format will be saved in {}".format(outFile)) + panopticFolder = os.path.join(outputFolder, outputBaseFile) + if not os.path.isdir(panopticFolder): + print("Creating folder {} for panoptic segmentation PNGs".format(panopticFolder)) + os.mkdir(panopticFolder) + print("Corresponding segmentations in .png format will be saved in {}".format(panopticFolder)) + + images = [] + annotations = [] + for progress, f in enumerate(files): + + originalFormat = np.array(Image.open(f)) + + fileName = os.path.basename(f) + location = fileName.split('_')[0] + imageId = fileName.replace("_gtFine_instanceIds.png", "") + fileName = os.path.join(location, fileName) + inputFileName = fileName.replace("_gtFine_instanceIds.png", "_leftImg8bit.png") + outputFileName = fileName.replace("_gtFine_instanceIds.png", "_panoptic.png") + # image entry, id for image is its filename without extension + images.append({"id": imageId, + "width": int(originalFormat.shape[1]), + "height": int(originalFormat.shape[0]), + "file_name": inputFileName}) + + pan_format = np.zeros( + (originalFormat.shape[0], originalFormat.shape[1], 3), dtype=np.uint8 + ) + + segmentIds = np.unique(originalFormat) + segmInfo = [] + for segmentId in segmentIds: + if segmentId < 1000: + semanticId = segmentId + isCrowd = 1 + else: + semanticId = segmentId // 1000 + isCrowd = 0 + labelInfo = id2label[semanticId] + categoryId = labelInfo.trainId if useTrainId else labelInfo.id + if labelInfo.ignoreInEval: + continue + 
if not labelInfo.hasInstances: + isCrowd = 0 + + mask = originalFormat == segmentId + color = [segmentId % 256, segmentId // 256, segmentId // 256 // 256] + pan_format[mask] = color + + area = np.sum(mask) # segment area computation + + # bbox computation for a segment + hor = np.sum(mask, axis=0) + hor_idx = np.nonzero(hor)[0] + x = hor_idx[0] + width = hor_idx[-1] - x + 1 + vert = np.sum(mask, axis=1) + vert_idx = np.nonzero(vert)[0] + y = vert_idx[0] + height = vert_idx[-1] - y + 1 + bbox = [int(x), int(y), int(width), int(height)] + + segmInfo.append({"id": int(segmentId), + "category_id": int(categoryId), + "area": int(area), + "bbox": bbox, + "iscrowd": isCrowd}) + + annotations.append({'image_id': imageId, + 'file_name': outputFileName, + "segments_info": segmInfo}) + + mmengine.mkdir_or_exist(os.path.dirname(os.path.join(panopticFolder, outputFileName))) + Image.fromarray(pan_format).save(os.path.join(panopticFolder, outputFileName)) + + print("\rProgress: {:>3.2f} %".format((progress + 1) * 100 / len(files)), end=' ') + sys.stdout.flush() + + print("\nSaving the json file {}".format(outFile)) + d = {'images': images, + 'annotations': annotations, + 'categories': categories} + with open(outFile, 'w') as f: + json.dump(d, f, sort_keys=True, indent=4) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--dataset-folder", + dest="cityscapesPath", + help="path to the Cityscapes dataset 'gtFine' folder", + default=None, + type=str) + parser.add_argument("--output-folder", + dest="outputFolder", + help="path to the output folder.", + default=None, + type=str) + parser.add_argument("--use-train-id", default=True,action="store_true", dest="useTrainId") + parser.add_argument("--set-names", + dest="setNames", + help="set names to which apply the function to", + nargs='+', + default=["val", "train"], + type=str) + args = parser.parse_args() + + convert2panoptic(args.cityscapesPath, args.outputFolder, args.useTrainId, args.setNames) + + +# call the main +if __name__ == "__main__": + main() diff --git a/ext/cityscapes_scripts/helpers/__init__.py b/ext/cityscapes_scripts/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc63beba4a21ad527d0821d228c5192f73fe1183 --- /dev/null +++ b/ext/cityscapes_scripts/helpers/__init__.py @@ -0,0 +1 @@ +# empty \ No newline at end of file diff --git a/ext/cityscapes_scripts/helpers/annotation.py b/ext/cityscapes_scripts/helpers/annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..67326bfd0f33894f186320981b917b994916ae7a --- /dev/null +++ b/ext/cityscapes_scripts/helpers/annotation.py @@ -0,0 +1,441 @@ +#!/usr/bin/python +# +# Classes to store, read, and write annotations +# + +from __future__ import print_function, absolute_import, division +import os +import json +import numpy as np +from collections import namedtuple + +# get current date and time +import datetime +import locale + +from abc import ABCMeta, abstractmethod +from .box3dImageTransform import Camera + +# A point in a polygon +Point = namedtuple('Point', ['x', 'y']) + + +class CsObjectType(): + """Type of an object""" + POLY = 1 # polygon + BBOX2D = 2 # bounding box + BBOX3D = 3 # 3d bounding box + IGNORE2D = 4 # 2d ignore region + + +class CsObject: + """Abstract base class for annotation objects""" + __metaclass__ = ABCMeta + + def __init__(self, objType): + self.objectType = objType + # the label + self.label = "" + + # If deleted or not + self.deleted = 0 + # If verified or not + self.verified = 0 + # 
The date string + self.date = "" + # The username + self.user = "" + # Draw the object + # Not read from or written to JSON + # Set to False if deleted object + # Might be set to False by the application for other reasons + self.draw = True + + @abstractmethod + def __str__(self): pass + + @abstractmethod + def fromJsonText(self, jsonText, objId=-1): pass + + @abstractmethod + def toJsonText(self): pass + + def updateDate(self): + try: + locale.setlocale(locale.LC_ALL, 'en_US.utf8') + except locale.Error: + locale.setlocale(locale.LC_ALL, 'en_US') + except locale.Error: + locale.setlocale(locale.LC_ALL, 'us_us.utf8') + except locale.Error: + locale.setlocale(locale.LC_ALL, 'us_us') + except Exception: + pass + self.date = datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S") + + # Mark the object as deleted + def delete(self): + self.deleted = 1 + self.draw = False + + +class CsPoly(CsObject): + """Class that contains the information of a single annotated object as polygon""" + + # Constructor + def __init__(self): + CsObject.__init__(self, CsObjectType.POLY) + # the polygon as list of points + self.polygon = [] + # the object ID + self.id = -1 + + def __str__(self): + polyText = "" + if self.polygon: + if len(self.polygon) <= 4: + for p in self.polygon: + polyText += '({},{}) '.format(p.x, p.y) + else: + polyText += '({},{}) ({},{}) ... ({},{}) ({},{})'.format( + self.polygon[0].x, self.polygon[0].y, + self.polygon[1].x, self.polygon[1].y, + self.polygon[-2].x, self.polygon[-2].y, + self.polygon[-1].x, self.polygon[-1].y) + else: + polyText = "none" + text = "Object: {} - {}".format(self.label, polyText) + return text + + def fromJsonText(self, jsonText, objId=-1): + self.id = objId + self.label = str(jsonText['label']) + self.polygon = [Point(p[0], p[1]) for p in jsonText['polygon']] + if 'deleted' in jsonText.keys(): + self.deleted = jsonText['deleted'] + else: + self.deleted = 0 + if 'verified' in jsonText.keys(): + self.verified = jsonText['verified'] + else: + self.verified = 1 + if 'user' in jsonText.keys(): + self.user = jsonText['user'] + else: + self.user = '' + if 'date' in jsonText.keys(): + self.date = jsonText['date'] + else: + self.date = '' + if self.deleted == 1: + self.draw = False + else: + self.draw = True + + def toJsonText(self): + objDict = {} + objDict['label'] = self.label + objDict['id'] = self.id + objDict['deleted'] = self.deleted + objDict['verified'] = self.verified + objDict['user'] = self.user + objDict['date'] = self.date + objDict['polygon'] = [] + for pt in self.polygon: + objDict['polygon'].append([pt.x, pt.y]) + + return objDict + + +class CsBbox2d(CsObject): + """Class that contains the information of a single annotated object as bounding box""" + + # Constructor + def __init__(self): + CsObject.__init__(self, CsObjectType.BBOX2D) + # the polygon as list of points + self.bbox_amodal_xywh = [] + self.bbox_modal_xywh = [] + + # the ID of the corresponding object + self.instanceId = -1 + # the label of the corresponding object + self.label = "" + + def __str__(self): + bboxAmodalText = "" + bboxAmodalText += '[(x1: {}, y1: {}), (w: {}, h: {})]'.format( + self.bbox_amodal_xywh[0], self.bbox_amodal_xywh[1], self.bbox_amodal_xywh[2], self.bbox_amodal_xywh[3]) + + bboxModalText = "" + bboxModalText += '[(x1: {}, y1: {}), (w: {}, h: {})]'.format( + self.bbox_modal_xywh[0], self.bbox_modal_xywh[1], self.bbox_modal_xywh[2], self.bbox_modal_xywh[3]) + + text = "Object: {}\n - Amodal {}\n - Modal {}".format( + self.label, bboxAmodalText, bboxModalText) + return 
text + + def setAmodalBox(self, bbox_amodal): + # sets the amodal box if required + self.bbox_amodal_xywh = [ + bbox_amodal[0], + bbox_amodal[1], + bbox_amodal[2] - bbox_amodal[0], + bbox_amodal[3] - bbox_amodal[1] + ] + + # access 2d boxes in [xmin, ymin, xmax, ymax] format + @property + def bbox_amodal(self): + """Returns the 2d box as [xmin, ymin, xmax, ymax]""" + return [ + self.bbox_amodal_xywh[0], + self.bbox_amodal_xywh[1], + self.bbox_amodal_xywh[0] + self.bbox_amodal_xywh[2], + self.bbox_amodal_xywh[1] + self.bbox_amodal_xywh[3] + ] + + @property + def bbox_modal(self): + """Returns the 2d box as [xmin, ymin, xmax, ymax]""" + return [ + self.bbox_modal_xywh[0], + self.bbox_modal_xywh[1], + self.bbox_modal_xywh[0] + self.bbox_modal_xywh[2], + self.bbox_modal_xywh[1] + self.bbox_modal_xywh[3] + ] + + def fromJsonText(self, jsonText, objId=-1): + # try to load from cityperson format + if 'bbox' in jsonText.keys() and 'bboxVis' in jsonText.keys(): + self.bbox_amodal_xywh = jsonText['bbox'] + self.bbox_modal_xywh = jsonText['bboxVis'] + # both modal and amodal boxes are provided + elif "modal" in jsonText.keys() and "amodal" in jsonText.keys(): + self.bbox_amodal_xywh = jsonText['amodal'] + self.bbox_modal_xywh = jsonText['modal'] + # only amodal boxes are provided + else: + self.bbox_modal_xywh = jsonText['amodal'] + self.bbox_amodal_xywh = jsonText['amodal'] + + # load label and instanceId if available + if 'label' in jsonText.keys() and 'instanceId' in jsonText.keys(): + self.label = str(jsonText['label']) + self.instanceId = jsonText['instanceId'] + + def toJsonText(self): + objDict = {} + objDict['label'] = self.label + objDict['instanceId'] = self.instanceId + objDict['modal'] = self.bbox_modal_xywh + objDict['amodal'] = self.bbox_amodal_xywh + + return objDict + + +class CsBbox3d(CsObject): + """Class that contains the information of a single annotated object as 3D bounding box""" + + # Constructor + def __init__(self): + CsObject.__init__(self, CsObjectType.BBOX3D) + + self.bbox_2d = None + + self.center = [] + self.dims = [] + self.rotation = [] + self.instanceId = -1 + self.label = "" + self.score = -1. 
+ + def __str__(self): + bbox2dText = str(self.bbox_2d) + + bbox3dText = "" + bbox3dText += '\n - Center (x/y/z) [m]: {}/{}/{}'.format( + self.center[0], self.center[1], self.center[2]) + bbox3dText += '\n - Dimensions (l/w/h) [m]: {}/{}/{}'.format( + self.dims[0], self.dims[1], self.dims[2]) + bbox3dText += '\n - Rotation: {}/{}/{}/{}'.format( + self.rotation[0], self.rotation[1], self.rotation[2], self.rotation[3]) + + text = "Object: {}\n2D {}\n - 3D {}".format( + self.label, bbox2dText, bbox3dText) + return text + + def fromJsonText(self, jsonText, objId=-1): + # load 2D box + self.bbox_2d = CsBbox2d() + self.bbox_2d.fromJsonText(jsonText['2d']) + + self.center = jsonText['3d']['center'] + self.dims = jsonText['3d']['dimensions'] + self.rotation = jsonText['3d']['rotation'] + self.label = jsonText['label'] + self.score = jsonText['score'] + + if 'instanceId' in jsonText.keys(): + self.instanceId = jsonText['instanceId'] + + def toJsonText(self): + objDict = {} + objDict['label'] = self.label + objDict['instanceId'] = self.instanceId + objDict['2d']['amodal'] = self.bbox_2d.bbox_amodal_xywh + objDict['2d']['modal'] = self.bbox_2d.bbox_modal_xywh + objDict['3d']['center'] = self.center + objDict['3d']['dimensions'] = self.dims + objDict['3d']['rotation'] = self.rotation + + return objDict + + @property + def depth(self): + # returns the BEV depth + return np.sqrt(self.center[0]**2 + self.center[1]**2).astype(int) + + +class CsIgnore2d(CsObject): + """Class that contains the information of a single annotated 2d ignore region""" + + # Constructor + def __init__(self): + CsObject.__init__(self, CsObjectType.IGNORE2D) + + self.bbox_xywh = [] + self.label = "" + self.instanceId = -1 + + def __str__(self): + bbox2dText = "" + bbox2dText += 'Ignore Region: (x1: {}, y1: {}), (w: {}, h: {})'.format( + self.bbox_xywh[0], self.bbox_xywh[1], self.bbox_xywh[2], self.bbox_xywh[3]) + + return bbox2dText + + def fromJsonText(self, jsonText, objId=-1): + self.bbox_xywh = jsonText['2d'] + + if 'label' in jsonText.keys(): + self.label = jsonText['label'] + + if 'instanceId' in jsonText.keys(): + self.instanceId = jsonText['instanceId'] + + def toJsonText(self): + objDict = {} + objDict['label'] = self.label + objDict['instanceId'] = self.instanceId + objDict['2d'] = self.bbox_xywh + + return objDict + + @property + def bbox(self): + """Returns the 2d box as [xmin, ymin, xmax, ymax]""" + return [ + self.bbox_xywh[0], + self.bbox_xywh[1], + self.bbox_xywh[0] + self.bbox_xywh[2], + self.bbox_xywh[1] + self.bbox_xywh[3] + ] + + # Extend api to be compatible to bbox2d + @property + def bbox_amodal_xywh(self): + return self.bbox_xywh + + @property + def bbox_modal_xywh(self): + return self.bbox_xywh + + +class Annotation: + """The annotation of a whole image (doesn't support mixed annotations, i.e. 
combining CsPoly and CsBbox2d)""" + + # Constructor + def __init__(self, objType=CsObjectType.POLY): + # the width of that image and thus of the label image + self.imgWidth = 0 + # the height of that image and thus of the label image + self.imgHeight = 0 + # the list of objects + self.objects = [] + # the camera calibration + self.camera = None + assert objType in CsObjectType.__dict__.values() + self.objectType = objType + + def toJson(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + def fromJsonText(self, jsonText): + jsonDict = json.loads(jsonText) + self.imgWidth = int(jsonDict['imgWidth']) + self.imgHeight = int(jsonDict['imgHeight']) + self.objects = [] + # load objects + if self.objectType != CsObjectType.IGNORE2D: + for objId, objIn in enumerate(jsonDict['objects']): + if self.objectType == CsObjectType.POLY: + obj = CsPoly() + elif self.objectType == CsObjectType.BBOX2D: + obj = CsBbox2d() + elif self.objectType == CsObjectType.BBOX3D: + obj = CsBbox3d() + obj.fromJsonText(objIn, objId) + self.objects.append(obj) + + # load ignores + if 'ignore' in jsonDict.keys(): + for ignoreId, ignoreIn in enumerate(jsonDict['ignore']): + obj = CsIgnore2d() + obj.fromJsonText(ignoreIn, ignoreId) + self.objects.append(obj) + + # load camera calibration + if 'sensor' in jsonDict.keys(): + self.camera = Camera(fx=jsonDict['sensor']['fx'], + fy=jsonDict['sensor']['fy'], + u0=jsonDict['sensor']['u0'], + v0=jsonDict['sensor']['v0'], + sensor_T_ISO_8855=jsonDict['sensor']['sensor_T_ISO_8855']) + + def toJsonText(self): + jsonDict = {} + jsonDict['imgWidth'] = self.imgWidth + jsonDict['imgHeight'] = self.imgHeight + jsonDict['objects'] = [] + for obj in self.objects: + objDict = obj.toJsonText() + jsonDict['objects'].append(objDict) + + return jsonDict + + # Read a json formatted polygon file and return the annotation + def fromJsonFile(self, jsonFile): + if not os.path.isfile(jsonFile): + print('Given json file not found: {}'.format(jsonFile)) + return + with open(jsonFile, 'r') as f: + jsonText = f.read() + self.fromJsonText(jsonText) + + def toJsonFile(self, jsonFile): + with open(jsonFile, 'w') as f: + f.write(self.toJson()) + + +# a dummy example +if __name__ == "__main__": + obj = CsPoly() + obj.label = 'car' + obj.polygon.append(Point(0, 0)) + obj.polygon.append(Point(1, 0)) + obj.polygon.append(Point(1, 1)) + obj.polygon.append(Point(0, 1)) + + print(type(obj).__name__) + print(obj) diff --git a/ext/cityscapes_scripts/helpers/csHelpers.py b/ext/cityscapes_scripts/helpers/csHelpers.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea2053840b25b0a83787e5840387902dc3235c7 --- /dev/null +++ b/ext/cityscapes_scripts/helpers/csHelpers.py @@ -0,0 +1,129 @@ +#!/usr/bin/python +# +# Various helper methods and includes for Cityscapes +# + +# Python imports +from __future__ import print_function, absolute_import, division +import os +import sys +import getopt +import glob +import math +import json +from collections import namedtuple +import logging +import traceback + +# Image processing +from PIL import Image +from PIL import ImageDraw + +# Numpy for datastructures +import numpy as np + +# Cityscapes modules +# from .annotation import Annotation +from .labels import labels, name2label, id2label, trainId2label, category2labels + + +def printError(message): + """Print an error message and quit""" + print('ERROR: ' + str(message)) + sys.exit(-1) + + +class colors: + """Class for colors""" + RED = '\033[31;1m' + GREEN = '\033[32;1m' + YELLOW 
= '\033[33;1m' + BLUE = '\033[34;1m' + MAGENTA = '\033[35;1m' + CYAN = '\033[36;1m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + ENDC = '\033[0m' + + +def getColorEntry(val, args): + """Colored value output if colorized flag is activated.""" + + if not args.colorized: + return "" + if not isinstance(val, float) or math.isnan(val): + return colors.ENDC + if (val < .20): + return colors.RED + elif (val < .40): + return colors.YELLOW + elif (val < .60): + return colors.BLUE + elif (val < .80): + return colors.CYAN + else: + return colors.GREEN + + +# Cityscapes files have a typical filename structure +# ___[_]. +# This class contains the individual elements as members +# For the sequence and frame number, the strings are returned, including leading zeros +CsFile = namedtuple('csFile', ['city', 'sequenceNb', 'frameNb', 'type', 'type2', 'ext']) + + +def getCsFileInfo(fileName): + """Returns a CsFile object filled from the info in the given filename""" + baseName = os.path.basename(fileName) + parts = baseName.split('_') + parts = parts[:-1] + parts[-1].split('.') + if not parts: + printError('Cannot parse given filename ({}). Does not seem to be a valid Cityscapes file.'.format(fileName)) + if len(parts) == 5: + csFile = CsFile(*parts[:-1], type2="", ext=parts[-1]) + elif len(parts) == 6: + csFile = CsFile(*parts) + else: + printError('Found {} part(s) in given filename ({}). Expected 5 or 6.'.format(len(parts), fileName)) + + return csFile + + +def getCoreImageFileName(filename): + """Returns the part of Cityscapes filenames that is common to all data types + + e.g. for city_123456_123456_gtFine_polygons.json returns city_123456_123456 + """ + csFile = getCsFileInfo(filename) + return "{}_{}_{}".format(csFile.city, csFile.sequenceNb, csFile.frameNb) + + +def getDirectory(fileName): + """Returns the directory name for the given filename + + e.g. + fileName = "/foo/bar/foobar.txt" + return value is "bar" + Not much error checking though + """ + dirName = os.path.dirname(fileName) + return os.path.basename(dirName) + + +def ensurePath(path): + """Make sure that the given path exists""" + if not path: + return + if not os.path.isdir(path): + os.makedirs(path) + + +def writeDict2JSON(dictName, fileName): + """Write a dictionary as json file""" + with open(fileName, 'w') as f: + f.write(json.dumps(dictName, default=lambda o: o.__dict__, sort_keys=True, indent=4)) + + +# dummy main +if __name__ == "__main__": + printError("Only for include, not executable on its own.") diff --git a/ext/cityscapes_scripts/helpers/labels.py b/ext/cityscapes_scripts/helpers/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..b785efe364307e71938e57aaf9b3e3a7241bfdc7 --- /dev/null +++ b/ext/cityscapes_scripts/helpers/labels.py @@ -0,0 +1,182 @@ +#!/usr/bin/python +# +# Cityscapes labels +# + +from __future__ import print_function, absolute_import, division +from collections import namedtuple + + +#-------------------------------------------------------------------------------- +# Definitions +#-------------------------------------------------------------------------------- + +# a label and all meta information +Label = namedtuple( 'Label' , [ + + 'name' , # The identifier of this label, e.g. 'car', 'person', ... . + # We use them to uniquely name a class + + 'id' , # An integer ID that is associated with this label. 
+ # The IDs are used to represent the label in ground truth images + # An ID of -1 means that this label does not have an ID and thus + # is ignored when creating ground truth images (e.g. license plate). + # Do not modify these IDs, since exactly these IDs are expected by the + # evaluation server. + + 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create + # ground truth images with train IDs, using the tools provided in the + # 'preparation' folder. However, make sure to validate or submit results + # to our evaluation server using the regular IDs above! + # For trainIds, multiple labels might have the same ID. Then, these labels + # are mapped to the same class in the ground truth images. For the inverse + # mapping, we use the label that is defined first in the list below. + # For example, mapping all void-type classes to the same ID in training, + # might make sense for some approaches. + # Max value is 255! + + 'category' , # The name of the category that this label belongs to + + 'categoryId' , # The ID of this category. Used to create ground truth images + # on category level. + + 'hasInstances', # Whether this label distinguishes between single instances or not + + 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored + # during evaluations or not + + 'color' , # The color of this label + ] ) + + +#-------------------------------------------------------------------------------- +# A list of all labels +#-------------------------------------------------------------------------------- + +# Please adapt the train IDs as appropriate for your approach. +# Note that you might want to ignore labels with ID 255 during training. +# Further note that the current train IDs are only a suggestion. You can use whatever you like. +# Make sure to provide your results using the original IDs and not the training IDs. +# Note that many IDs are ignored in evaluation and thus you never need to predict these! 
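To make the id / trainId contract above concrete, here is a small illustrative sketch (not part of the original script) that uses the name2label and trainId2label tables defined further down in this file to map a training id back to the regular id expected by the evaluation server:

# Illustrative only; assumes the repository root is on PYTHONPATH.
from ext.cityscapes_scripts.helpers.labels import name2label, trainId2label

train_id = name2label['car'].trainId        # id used in train-id ground truth images
official_id = trainId2label[train_id].id    # regular id to report for evaluation
assert official_id == name2label['car'].id  # round-trips for non-ignored classes such as 'car'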
+ +labels = [ + # name id trainId category catId hasInstances ignoreInEval color + Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), + Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), + Label( 'road' , 7 , 0 + 8, 'flat' , 1 , False , False , (128, 64,128) ), + Label( 'sidewalk' , 8 , 1 + 8, 'flat' , 1 , False , False , (244, 35,232) ), + Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), + Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), + Label( 'building' , 11 , 2 + 8, 'construction' , 2 , False , False , ( 70, 70, 70) ), + Label( 'wall' , 12 , 3 + 8, 'construction' , 2 , False , False , (102,102,156) ), + Label( 'fence' , 13 , 4 + 8, 'construction' , 2 , False , False , (190,153,153) ), + Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), + Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), + Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), + Label( 'pole' , 17 , 5 + 8, 'object' , 3 , False , False , (153,153,153) ), + Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), + Label( 'traffic light' , 19 , 6 + 8, 'object' , 3 , False , False , (250,170, 30) ), + Label( 'traffic sign' , 20 , 7 + 8, 'object' , 3 , False , False , (220,220, 0) ), + Label( 'vegetation' , 21 , 8 + 8, 'nature' , 4 , False , False , (107,142, 35) ), + Label( 'terrain' , 22 , 9 + 8, 'nature' , 4 , False , False , (152,251,152) ), + Label( 'sky' , 23 , 10 + 8, 'sky' , 5 , False , False , ( 70,130,180) ), + Label( 'person' , 24 , 11 - 11 , 'human' , 6 , True , False , (220, 20, 60) ), + Label( 'rider' , 25 , 12 - 11 , 'human' , 6 , True , False , (255, 0, 0) ), + Label( 'car' , 26 , 13 - 11, 'vehicle' , 7 , True , False , ( 0, 0,142) ), + Label( 'truck' , 27 , 14 - 11, 'vehicle' , 7 , True , False , ( 0, 0, 70) ), + Label( 'bus' , 28 , 15 - 11, 'vehicle' , 7 , True , False , ( 0, 60,100) ), + Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), + Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), + Label( 'train' , 31 , 16 - 11, 'vehicle' , 7 , True , False , ( 0, 80,100) ), + Label( 'motorcycle' , 32 , 17 - 11, 'vehicle' , 7 , True , False , ( 0, 0,230) ), + Label( 'bicycle' , 33 , 18 - 11, 'vehicle' , 7 , True , False , (119, 11, 32) ), + Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), +] + + +#-------------------------------------------------------------------------------- +# Create dictionaries for a fast lookup +#-------------------------------------------------------------------------------- + +# Please refer to the main method below for example usages! 
+ +# name to label object +name2label = { label.name : label for label in labels } +# id to label object +id2label = { label.id : label for label in labels } +# trainId to label object +trainId2label = { label.trainId : label for label in reversed(labels) } +# category to list of label objects +category2labels = {} +for label in labels: + category = label.category + if category in category2labels: + category2labels[category].append(label) + else: + category2labels[category] = [label] + +#-------------------------------------------------------------------------------- +# Assure single instance name +#-------------------------------------------------------------------------------- + +# returns the label name that describes a single instance (if possible) +# e.g. input | output +# ---------------------- +# car | car +# cargroup | car +# foo | None +# foogroup | None +# skygroup | None +def assureSingleInstanceName( name ): + # if the name is known, it is not a group + if name in name2label: + return name + # test if the name actually denotes a group + if not name.endswith("group"): + return None + # remove group + name = name[:-len("group")] + # test if the new name exists + if not name in name2label: + return None + # test if the new name denotes a label that actually has instances + if not name2label[name].hasInstances: + return None + # all good then + return name + +#-------------------------------------------------------------------------------- +# Main for testing +#-------------------------------------------------------------------------------- + +# just a dummy main +if __name__ == "__main__": + # Print all the labels + print("List of cityscapes labels:") + print("") + print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) + print(" " + ('-' * 98)) + for label in labels: + print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) + print("") + + print("Example usages:") + + # Map from name to label + name = 'car' + id = name2label[name].id + print("ID of label '{name}': {id}".format( name=name, id=id )) + + # Map from ID to label + category = id2label[id].category + print("Category of label with ID '{id}': {category}".format( id=id, category=category )) + + # Map from trainID to label + trainId = 0 + name = trainId2label[trainId].name + print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )) diff --git a/ext/cityscapes_scripts/helpers/labels_cityPersons.py b/ext/cityscapes_scripts/helpers/labels_cityPersons.py new file mode 100644 index 0000000000000000000000000000000000000000..7f81abbc2822071293f98bc84e6585bb30014311 --- /dev/null +++ b/ext/cityscapes_scripts/helpers/labels_cityPersons.py @@ -0,0 +1,61 @@ +#!/usr/bin/python +# +# CityPersons (cp) labels +# + +from __future__ import print_function, absolute_import, division +from collections import namedtuple + + +#-------------------------------------------------------------------------------- +# Definitions +#-------------------------------------------------------------------------------- + +# a label and all meta information +LabelCp = namedtuple( 'LabelCp' , [ + + 'name' , # The identifier of this label, e.g. 'pedestrian', 'rider', ... . + # We use them to uniquely name a class + + 'id' , # An integer ID that is associated with this label. 
+ # The IDs are used to represent the label in ground truth + + 'hasInstances', # Whether this label distinguishes between single instances or not + + 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored + # during evaluations or not + + 'color' , # The color of this label + ] ) + + +#-------------------------------------------------------------------------------- +# A list of all labels +#-------------------------------------------------------------------------------- + +# The 'ignore' label covers representations of humans, e.g. people on posters, reflections etc. +# Each annotation includes both the full bounding box (bbox) as well as a bounding box covering the visible area (bboxVis). +# The latter is obtained automatically from the segmentation masks. + +labelsCp = [ + # name id hasInstances ignoreInEval color + LabelCp( 'ignore' , 0 , False , True , (250,170, 30) ), + LabelCp( 'pedestrian' , 1 , True , False , (220, 20, 60) ), + LabelCp( 'rider' , 2 , True , False , ( 0, 0,142) ), + LabelCp( 'sitting person' , 3 , True , False , (107,142, 35) ), + LabelCp( 'person (other)' , 4 , True , False , (190,153,153) ), + LabelCp( 'person group' , 5 , False , True , (255, 0, 0) ), +] + + +#-------------------------------------------------------------------------------- +# Create dictionaries for a fast lookup +#-------------------------------------------------------------------------------- + +# Please refer to the main method below for example usages! + +# name to label object +name2labelCp = { label.name : label for label in labelsCp } +# id to label object +id2labelCp = { label.id : label for label in labelsCp } + diff --git a/ext/cityscapes_scripts/helpers/version.py b/ext/cityscapes_scripts/helpers/version.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6b64e3a811c43774dd1fc47197514155c6d68f --- /dev/null +++ b/ext/cityscapes_scripts/helpers/version.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python + +import os + +with open(os.path.join(os.path.dirname(__file__), '..', 'VERSION')) as f: + version = f.read().strip() + +if __name__ == "__main__": + print(version) diff --git a/ext/class_names/VIPSeg.py b/ext/class_names/VIPSeg.py new file mode 100644 index 0000000000000000000000000000000000000000..54ab1c886bfd396f5b666ad1035afe9553dc2e27 --- /dev/null +++ b/ext/class_names/VIPSeg.py @@ -0,0 +1,261 @@ +CLASSES = [ + {"id": 0, "name": "wall", "isthing": 0, "color": [120, 120, 120]}, + {"id": 1, "name": "ceiling", "isthing": 0, "color": [180, 120, 120]}, + {"id": 2, "name": "door", "isthing": 1, "color": [6, 230, 230]}, + {"id": 3, "name": "stair", "isthing": 0, "color": [80, 50, 50]}, + {"id": 4, "name": "ladder", "isthing": 1, "color": [4, 200, 3]}, + {"id": 5, "name": "escalator", "isthing": 0, "color": [120, 120, 80]}, + {"id": 6, "name": "Playground_slide", "isthing": 0, "color": [140, 140, 140]}, + {"id": 7, "name": "handrail_or_fence", "isthing": 0, "color": [204, 5, 255]}, + {"id": 8, "name": "window", "isthing": 1, "color": [230, 230, 230]}, + {"id": 9, "name": "rail", "isthing": 0, "color": [4, 250, 7]}, + {"id": 10, "name": "goal", "isthing": 1, "color": [224, 5, 255]}, + {"id": 11, "name": "pillar", "isthing": 0, "color": [235, 255, 7]}, + {"id": 12, "name": "pole", "isthing": 0, "color": [150, 5, 61]}, + {"id": 13, "name": "floor", "isthing": 0, "color": [120, 120, 70]}, + {"id": 14, "name": "ground", "isthing": 0, "color": [8, 255, 51]}, + {"id": 15, "name": "grass", "isthing": 0, "color": [255, 6, 82]}, + {"id": 16, "name": 
"sand", "isthing": 0, "color": [143, 255, 140]}, + {"id": 17, "name": "athletic_field", "isthing": 0, "color": [204, 255, 4]}, + {"id": 18, "name": "road", "isthing": 0, "color": [255, 51, 7]}, + {"id": 19, "name": "path", "isthing": 0, "color": [204, 70, 3]}, + {"id": 20, "name": "crosswalk", "isthing": 0, "color": [0, 102, 200]}, + {"id": 21, "name": "building", "isthing": 0, "color": [61, 230, 250]}, + {"id": 22, "name": "house", "isthing": 0, "color": [255, 6, 51]}, + {"id": 23, "name": "bridge", "isthing": 0, "color": [11, 102, 255]}, + {"id": 24, "name": "tower", "isthing": 0, "color": [255, 7, 71]}, + {"id": 25, "name": "windmill", "isthing": 0, "color": [255, 9, 224]}, + {"id": 26, "name": "well_or_well_lid", "isthing": 0, "color": [9, 7, 230]}, + {"id": 27, "name": "other_construction", "isthing": 0, "color": [220, 220, 220]}, + {"id": 28, "name": "sky", "isthing": 0, "color": [255, 9, 92]}, + {"id": 29, "name": "mountain", "isthing": 0, "color": [112, 9, 255]}, + {"id": 30, "name": "stone", "isthing": 0, "color": [8, 255, 214]}, + {"id": 31, "name": "wood", "isthing": 0, "color": [7, 255, 224]}, + {"id": 32, "name": "ice", "isthing": 0, "color": [255, 184, 6]}, + {"id": 33, "name": "snowfield", "isthing": 0, "color": [10, 255, 71]}, + {"id": 34, "name": "grandstand", "isthing": 0, "color": [255, 41, 10]}, + {"id": 35, "name": "sea", "isthing": 0, "color": [7, 255, 255]}, + {"id": 36, "name": "river", "isthing": 0, "color": [224, 255, 8]}, + {"id": 37, "name": "lake", "isthing": 0, "color": [102, 8, 255]}, + {"id": 38, "name": "waterfall", "isthing": 0, "color": [255, 61, 6]}, + {"id": 39, "name": "water", "isthing": 0, "color": [255, 194, 7]}, + {"id": 40, "name": "billboard_or_Bulletin_Board", "isthing": 0, "color": [255, 122, 8]}, + {"id": 41, "name": "sculpture", "isthing": 1, "color": [0, 255, 20]}, + {"id": 42, "name": "pipeline", "isthing": 0, "color": [255, 8, 41]}, + {"id": 43, "name": "flag", "isthing": 1, "color": [255, 5, 153]}, + {"id": 44, "name": "parasol_or_umbrella", "isthing": 1, "color": [6, 51, 255]}, + {"id": 45, "name": "cushion_or_carpet", "isthing": 0, "color": [235, 12, 255]}, + {"id": 46, "name": "tent", "isthing": 1, "color": [160, 150, 20]}, + {"id": 47, "name": "roadblock", "isthing": 1, "color": [0, 163, 255]}, + {"id": 48, "name": "car", "isthing": 1, "color": [140, 140, 140]}, + {"id": 49, "name": "bus", "isthing": 1, "color": [250, 10, 15]}, + {"id": 50, "name": "truck", "isthing": 1, "color": [20, 255, 0]}, + {"id": 51, "name": "bicycle", "isthing": 1, "color": [31, 255, 0]}, + {"id": 52, "name": "motorcycle", "isthing": 1, "color": [255, 31, 0]}, + {"id": 53, "name": "wheeled_machine", "isthing": 0, "color": [255, 224, 0]}, + {"id": 54, "name": "ship_or_boat", "isthing": 1, "color": [153, 255, 0]}, + {"id": 55, "name": "raft", "isthing": 1, "color": [0, 0, 255]}, + {"id": 56, "name": "airplane", "isthing": 1, "color": [255, 71, 0]}, + {"id": 57, "name": "tyre", "isthing": 0, "color": [0, 235, 255]}, + {"id": 58, "name": "traffic_light", "isthing": 0, "color": [0, 173, 255]}, + {"id": 59, "name": "lamp", "isthing": 0, "color": [31, 0, 255]}, + {"id": 60, "name": "person", "isthing": 1, "color": [11, 200, 200]}, + {"id": 61, "name": "cat", "isthing": 1, "color": [255, 82, 0]}, + {"id": 62, "name": "dog", "isthing": 1, "color": [0, 255, 245]}, + {"id": 63, "name": "horse", "isthing": 1, "color": [0, 61, 255]}, + {"id": 64, "name": "cattle", "isthing": 1, "color": [0, 255, 112]}, + {"id": 65, "name": "other_animal", "isthing": 1, "color": [0, 255, 
133]}, + {"id": 66, "name": "tree", "isthing": 0, "color": [255, 0, 0]}, + {"id": 67, "name": "flower", "isthing": 0, "color": [255, 163, 0]}, + {"id": 68, "name": "other_plant", "isthing": 0, "color": [255, 102, 0]}, + {"id": 69, "name": "toy", "isthing": 0, "color": [194, 255, 0]}, + {"id": 70, "name": "ball_net", "isthing": 0, "color": [0, 143, 255]}, + {"id": 71, "name": "backboard", "isthing": 0, "color": [51, 255, 0]}, + {"id": 72, "name": "skateboard", "isthing": 1, "color": [0, 82, 255]}, + {"id": 73, "name": "bat", "isthing": 0, "color": [0, 255, 41]}, + {"id": 74, "name": "ball", "isthing": 1, "color": [0, 255, 173]}, + {"id": 75, "name": "cupboard_or_showcase_or_storage_rack", "isthing": 0, "color": [10, 0, 255]}, + {"id": 76, "name": "box", "isthing": 1, "color": [173, 255, 0]}, + {"id": 77, "name": "traveling_case_or_trolley_case", "isthing": 1, "color": [0, 255, 153]}, + {"id": 78, "name": "basket", "isthing": 1, "color": [255, 92, 0]}, + {"id": 79, "name": "bag_or_package", "isthing": 1, "color": [255, 0, 255]}, + {"id": 80, "name": "trash_can", "isthing": 0, "color": [255, 0, 245]}, + {"id": 81, "name": "cage", "isthing": 0, "color": [255, 0, 102]}, + {"id": 82, "name": "plate", "isthing": 1, "color": [255, 173, 0]}, + {"id": 83, "name": "tub_or_bowl_or_pot", "isthing": 1, "color": [255, 0, 20]}, + {"id": 84, "name": "bottle_or_cup", "isthing": 1, "color": [255, 184, 184]}, + {"id": 85, "name": "barrel", "isthing": 1, "color": [0, 31, 255]}, + {"id": 86, "name": "fishbowl", "isthing": 1, "color": [0, 255, 61]}, + {"id": 87, "name": "bed", "isthing": 1, "color": [0, 71, 255]}, + {"id": 88, "name": "pillow", "isthing": 1, "color": [255, 0, 204]}, + {"id": 89, "name": "table_or_desk", "isthing": 1, "color": [0, 255, 194]}, + {"id": 90, "name": "chair_or_seat", "isthing": 1, "color": [0, 255, 82]}, + {"id": 91, "name": "bench", "isthing": 1, "color": [0, 10, 255]}, + {"id": 92, "name": "sofa", "isthing": 1, "color": [0, 112, 255]}, + {"id": 93, "name": "shelf", "isthing": 0, "color": [51, 0, 255]}, + {"id": 94, "name": "bathtub", "isthing": 0, "color": [0, 194, 255]}, + {"id": 95, "name": "gun", "isthing": 1, "color": [0, 122, 255]}, + {"id": 96, "name": "commode", "isthing": 1, "color": [0, 255, 163]}, + {"id": 97, "name": "roaster", "isthing": 1, "color": [255, 153, 0]}, + {"id": 98, "name": "other_machine", "isthing": 0, "color": [0, 255, 10]}, + {"id": 99, "name": "refrigerator", "isthing": 1, "color": [255, 112, 0]}, + {"id": 100, "name": "washing_machine", "isthing": 1, "color": [143, 255, 0]}, + {"id": 101, "name": "Microwave_oven", "isthing": 1, "color": [82, 0, 255]}, + {"id": 102, "name": "fan", "isthing": 1, "color": [163, 255, 0]}, + {"id": 103, "name": "curtain", "isthing": 0, "color": [255, 235, 0]}, + {"id": 104, "name": "textiles", "isthing": 0, "color": [8, 184, 170]}, + {"id": 105, "name": "clothes", "isthing": 0, "color": [133, 0, 255]}, + {"id": 106, "name": "painting_or_poster", "isthing": 1, "color": [0, 255, 92]}, + {"id": 107, "name": "mirror", "isthing": 1, "color": [184, 0, 255]}, + {"id": 108, "name": "flower_pot_or_vase", "isthing": 1, "color": [255, 0, 31]}, + {"id": 109, "name": "clock", "isthing": 1, "color": [0, 184, 255]}, + {"id": 110, "name": "book", "isthing": 0, "color": [0, 214, 255]}, + {"id": 111, "name": "tool", "isthing": 0, "color": [255, 0, 112]}, + {"id": 112, "name": "blackboard", "isthing": 0, "color": [92, 255, 0]}, + {"id": 113, "name": "tissue", "isthing": 0, "color": [0, 224, 255]}, + {"id": 114, "name": "screen_or_television", 
"isthing": 1, "color": [112, 224, 255]}, + {"id": 115, "name": "computer", "isthing": 1, "color": [70, 184, 160]}, + {"id": 116, "name": "printer", "isthing": 1, "color": [163, 0, 255]}, + {"id": 117, "name": "Mobile_phone", "isthing": 1, "color": [153, 0, 255]}, + {"id": 118, "name": "keyboard", "isthing": 1, "color": [71, 255, 0]}, + {"id": 119, "name": "other_electronic_product", "isthing": 0, "color": [255, 0, 163]}, + {"id": 120, "name": "fruit", "isthing": 0, "color": [255, 204, 0]}, + {"id": 121, "name": "food", "isthing": 0, "color": [255, 0, 143]}, + {"id": 122, "name": "instrument", "isthing": 1, "color": [0, 255, 235]}, + {"id": 123, "name": "train", "isthing": 1, "color": [133, 255, 0]} +] + +CLASSES_THING = [ + {'id': 2, 'name': 'door', 'isthing': 1, 'color': [6, 230, 230]}, + {'id': 4, 'name': 'ladder', 'isthing': 1, 'color': [4, 200, 3]}, + {'id': 8, 'name': 'window', 'isthing': 1, 'color': [230, 230, 230]}, + {'id': 10, 'name': 'goal', 'isthing': 1, 'color': [224, 5, 255]}, + {'id': 41, 'name': 'sculpture', 'isthing': 1, 'color': [0, 255, 20]}, + {'id': 43, 'name': 'flag', 'isthing': 1, 'color': [255, 5, 153]}, + {'id': 44, 'name': 'parasol_or_umbrella', 'isthing': 1, 'color': [6, 51, 255]}, + {'id': 46, 'name': 'tent', 'isthing': 1, 'color': [160, 150, 20]}, + {'id': 47, 'name': 'roadblock', 'isthing': 1, 'color': [0, 163, 255]}, + {'id': 48, 'name': 'car', 'isthing': 1, 'color': [140, 140, 140]}, + {'id': 49, 'name': 'bus', 'isthing': 1, 'color': [250, 10, 15]}, + {'id': 50, 'name': 'truck', 'isthing': 1, 'color': [20, 255, 0]}, + {'id': 51, 'name': 'bicycle', 'isthing': 1, 'color': [31, 255, 0]}, + {'id': 52, 'name': 'motorcycle', 'isthing': 1, 'color': [255, 31, 0]}, + {'id': 54, 'name': 'ship_or_boat', 'isthing': 1, 'color': [153, 255, 0]}, + {'id': 55, 'name': 'raft', 'isthing': 1, 'color': [0, 0, 255]}, + {'id': 56, 'name': 'airplane', 'isthing': 1, 'color': [255, 71, 0]}, + {'id': 60, 'name': 'person', 'isthing': 1, 'color': [11, 200, 200]}, + {'id': 61, 'name': 'cat', 'isthing': 1, 'color': [255, 82, 0]}, + {'id': 62, 'name': 'dog', 'isthing': 1, 'color': [0, 255, 245]}, + {'id': 63, 'name': 'horse', 'isthing': 1, 'color': [0, 61, 255]}, + {'id': 64, 'name': 'cattle', 'isthing': 1, 'color': [0, 255, 112]}, + {'id': 65, 'name': 'other_animal', 'isthing': 1, 'color': [0, 255, 133]}, + {'id': 72, 'name': 'skateboard', 'isthing': 1, 'color': [0, 82, 255]}, + {'id': 74, 'name': 'ball', 'isthing': 1, 'color': [0, 255, 173]}, + {'id': 76, 'name': 'box', 'isthing': 1, 'color': [173, 255, 0]}, + {'id': 77, 'name': 'traveling_case_or_trolley_case', 'isthing': 1, 'color': [0, 255, 153]}, + {'id': 78, 'name': 'basket', 'isthing': 1, 'color': [255, 92, 0]}, + {'id': 79, 'name': 'bag_or_package', 'isthing': 1, 'color': [255, 0, 255]}, + {'id': 82, 'name': 'plate', 'isthing': 1, 'color': [255, 173, 0]}, + {'id': 83, 'name': 'tub_or_bowl_or_pot', 'isthing': 1, 'color': [255, 0, 20]}, + {'id': 84, 'name': 'bottle_or_cup', 'isthing': 1, 'color': [255, 184, 184]}, + {'id': 85, 'name': 'barrel', 'isthing': 1, 'color': [0, 31, 255]}, + {'id': 86, 'name': 'fishbowl', 'isthing': 1, 'color': [0, 255, 61]}, + {'id': 87, 'name': 'bed', 'isthing': 1, 'color': [0, 71, 255]}, + {'id': 88, 'name': 'pillow', 'isthing': 1, 'color': [255, 0, 204]}, + {'id': 89, 'name': 'table_or_desk', 'isthing': 1, 'color': [0, 255, 194]}, + {'id': 90, 'name': 'chair_or_seat', 'isthing': 1, 'color': [0, 255, 82]}, + {'id': 91, 'name': 'bench', 'isthing': 1, 'color': [0, 10, 255]}, + {'id': 92, 'name': 'sofa', 
'isthing': 1, 'color': [0, 112, 255]}, + {'id': 95, 'name': 'gun', 'isthing': 1, 'color': [0, 122, 255]}, + {'id': 96, 'name': 'commode', 'isthing': 1, 'color': [0, 255, 163]}, + {'id': 97, 'name': 'roaster', 'isthing': 1, 'color': [255, 153, 0]}, + {'id': 99, 'name': 'refrigerator', 'isthing': 1, 'color': [255, 112, 0]}, + {'id': 100, 'name': 'washing_machine', 'isthing': 1, 'color': [143, 255, 0]}, + {'id': 101, 'name': 'Microwave_oven', 'isthing': 1, 'color': [82, 0, 255]}, + {'id': 102, 'name': 'fan', 'isthing': 1, 'color': [163, 255, 0]}, + {'id': 106, 'name': 'painting_or_poster', 'isthing': 1, 'color': [0, 255, 92]}, + {'id': 107, 'name': 'mirror', 'isthing': 1, 'color': [184, 0, 255]}, + {'id': 108, 'name': 'flower_pot_or_vase', 'isthing': 1, 'color': [255, 0, 31]}, + {'id': 109, 'name': 'clock', 'isthing': 1, 'color': [0, 184, 255]}, + {'id': 114, 'name': 'screen_or_television', 'isthing': 1, 'color': [112, 224, 255]}, + {'id': 115, 'name': 'computer', 'isthing': 1, 'color': [70, 184, 160]}, + {'id': 116, 'name': 'printer', 'isthing': 1, 'color': [163, 0, 255]}, + {'id': 117, 'name': 'Mobile_phone', 'isthing': 1, 'color': [153, 0, 255]}, + {'id': 118, 'name': 'keyboard', 'isthing': 1, 'color': [71, 255, 0]}, + {'id': 122, 'name': 'instrument', 'isthing': 1, 'color': [0, 255, 235]}, + {'id': 123, 'name': 'train', 'isthing': 1, 'color': [133, 255, 0]} +] + +CLASSES_STUFF = [ + {'id': 0, 'name': 'wall', 'isthing': 0, 'color': [120, 120, 120]}, + {'id': 1, 'name': 'ceiling', 'isthing': 0, 'color': [180, 120, 120]}, + {'id': 3, 'name': 'stair', 'isthing': 0, 'color': [80, 50, 50]}, + {'id': 5, 'name': 'escalator', 'isthing': 0, 'color': [120, 120, 80]}, + {'id': 6, 'name': 'Playground_slide', 'isthing': 0, 'color': [140, 140, 140]}, + {'id': 7, 'name': 'handrail_or_fence', 'isthing': 0, 'color': [204, 5, 255]}, + {'id': 9, 'name': 'rail', 'isthing': 0, 'color': [4, 250, 7]}, + {'id': 11, 'name': 'pillar', 'isthing': 0, 'color': [235, 255, 7]}, + {'id': 12, 'name': 'pole', 'isthing': 0, 'color': [150, 5, 61]}, + {'id': 13, 'name': 'floor', 'isthing': 0, 'color': [120, 120, 70]}, + {'id': 14, 'name': 'ground', 'isthing': 0, 'color': [8, 255, 51]}, + {'id': 15, 'name': 'grass', 'isthing': 0, 'color': [255, 6, 82]}, + {'id': 16, 'name': 'sand', 'isthing': 0, 'color': [143, 255, 140]}, + {'id': 17, 'name': 'athletic_field', 'isthing': 0, 'color': [204, 255, 4]}, + {'id': 18, 'name': 'road', 'isthing': 0, 'color': [255, 51, 7]}, + {'id': 19, 'name': 'path', 'isthing': 0, 'color': [204, 70, 3]}, + {'id': 20, 'name': 'crosswalk', 'isthing': 0, 'color': [0, 102, 200]}, + {'id': 21, 'name': 'building', 'isthing': 0, 'color': [61, 230, 250]}, + {'id': 22, 'name': 'house', 'isthing': 0, 'color': [255, 6, 51]}, + {'id': 23, 'name': 'bridge', 'isthing': 0, 'color': [11, 102, 255]}, + {'id': 24, 'name': 'tower', 'isthing': 0, 'color': [255, 7, 71]}, + {'id': 25, 'name': 'windmill', 'isthing': 0, 'color': [255, 9, 224]}, + {'id': 26, 'name': 'well_or_well_lid', 'isthing': 0, 'color': [9, 7, 230]}, + {'id': 27, 'name': 'other_construction', 'isthing': 0, 'color': [220, 220, 220]}, + {'id': 28, 'name': 'sky', 'isthing': 0, 'color': [255, 9, 92]}, + {'id': 29, 'name': 'mountain', 'isthing': 0, 'color': [112, 9, 255]}, + {'id': 30, 'name': 'stone', 'isthing': 0, 'color': [8, 255, 214]}, + {'id': 31, 'name': 'wood', 'isthing': 0, 'color': [7, 255, 224]}, + {'id': 32, 'name': 'ice', 'isthing': 0, 'color': [255, 184, 6]}, + {'id': 33, 'name': 'snowfield', 'isthing': 0, 'color': [10, 255, 71]}, + {'id': 34, 
'name': 'grandstand', 'isthing': 0, 'color': [255, 41, 10]}, + {'id': 35, 'name': 'sea', 'isthing': 0, 'color': [7, 255, 255]}, + {'id': 36, 'name': 'river', 'isthing': 0, 'color': [224, 255, 8]}, + {'id': 37, 'name': 'lake', 'isthing': 0, 'color': [102, 8, 255]}, + {'id': 38, 'name': 'waterfall', 'isthing': 0, 'color': [255, 61, 6]}, + {'id': 39, 'name': 'water', 'isthing': 0, 'color': [255, 194, 7]}, + {'id': 40, 'name': 'billboard_or_Bulletin_Board', 'isthing': 0, 'color': [255, 122, 8]}, + {'id': 42, 'name': 'pipeline', 'isthing': 0, 'color': [255, 8, 41]}, + {'id': 45, 'name': 'cushion_or_carpet', 'isthing': 0, 'color': [235, 12, 255]}, + {'id': 53, 'name': 'wheeled_machine', 'isthing': 0, 'color': [255, 224, 0]}, + {'id': 57, 'name': 'tyre', 'isthing': 0, 'color': [0, 235, 255]}, + {'id': 58, 'name': 'traffic_light', 'isthing': 0, 'color': [0, 173, 255]}, + {'id': 59, 'name': 'lamp', 'isthing': 0, 'color': [31, 0, 255]}, + {'id': 66, 'name': 'tree', 'isthing': 0, 'color': [255, 0, 0]}, + {'id': 67, 'name': 'flower', 'isthing': 0, 'color': [255, 163, 0]}, + {'id': 68, 'name': 'other_plant', 'isthing': 0, 'color': [255, 102, 0]}, + {'id': 69, 'name': 'toy', 'isthing': 0, 'color': [194, 255, 0]}, + {'id': 70, 'name': 'ball_net', 'isthing': 0, 'color': [0, 143, 255]}, + {'id': 71, 'name': 'backboard', 'isthing': 0, 'color': [51, 255, 0]}, + {'id': 73, 'name': 'bat', 'isthing': 0, 'color': [0, 255, 41]}, + {'id': 75, 'name': 'cupboard_or_showcase_or_storage_rack', 'isthing': 0, 'color': [10, 0, 255]}, + {'id': 80, 'name': 'trash_can', 'isthing': 0, 'color': [255, 0, 245]}, + {'id': 81, 'name': 'cage', 'isthing': 0, 'color': [255, 0, 102]}, + {'id': 93, 'name': 'shelf', 'isthing': 0, 'color': [51, 0, 255]}, + {'id': 94, 'name': 'bathtub', 'isthing': 0, 'color': [0, 194, 255]}, + {'id': 98, 'name': 'other_machine', 'isthing': 0, 'color': [0, 255, 10]}, + {'id': 103, 'name': 'curtain', 'isthing': 0, 'color': [255, 235, 0]}, + {'id': 104, 'name': 'textiles', 'isthing': 0, 'color': [8, 184, 170]}, + {'id': 105, 'name': 'clothes', 'isthing': 0, 'color': [133, 0, 255]}, + {'id': 110, 'name': 'book', 'isthing': 0, 'color': [0, 214, 255]}, + {'id': 111, 'name': 'tool', 'isthing': 0, 'color': [255, 0, 112]}, + {'id': 112, 'name': 'blackboard', 'isthing': 0, 'color': [92, 255, 0]}, + {'id': 113, 'name': 'tissue', 'isthing': 0, 'color': [0, 224, 255]}, + {'id': 119, 'name': 'other_electronic_product', 'isthing': 0, 'color': [255, 0, 163]}, + {'id': 120, 'name': 'fruit', 'isthing': 0, 'color': [255, 204, 0]}, + {'id': 121, 'name': 'food', 'isthing': 0, 'color': [255, 0, 143]} +] + +COCO_THINGS = [itm['name'] for itm in CLASSES_THING] +COCO_STUFF = [itm['name'] for itm in CLASSES_STUFF] +COCO_CLASSES = [*COCO_THINGS, *COCO_STUFF] +PALETTE = [*[itm['color'] for itm in CLASSES_THING], *[itm['color'] for itm in CLASSES_STUFF]] diff --git a/ext/davis2017/__init__.py b/ext/davis2017/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb263ea94c650a6ac55ef374f2cea1b8c96a1a5 --- /dev/null +++ b/ext/davis2017/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +__version__ = '0.1.0' diff --git a/ext/davis2017/davis.py b/ext/davis2017/davis.py new file mode 100644 index 0000000000000000000000000000000000000000..d831be61d5ce240374c099405cbbc62fb1249f3d --- /dev/null +++ b/ext/davis2017/davis.py @@ -0,0 +1,122 @@ +import os +from glob import glob +from collections import defaultdict +import numpy as np +from PIL import Image + + +class DAVIS(object): + 
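# Thin reader for the DAVIS 2017 directory layout: resolves per-sequence image/mask paths and exposes simple iterators over frames and masks. +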
SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] + TASKS = ['semi-supervised', 'unsupervised'] + DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' + VOID_LABEL = 255 + + def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False): + """ + Class to read the DAVIS dataset + :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. + :param task: Task to load the annotations, choose between semi-supervised or unsupervised. + :param subset: Set to load the annotations + :param sequences: Sequences to consider, 'all' to use all the sequences in a set. + :param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' + """ + if subset not in self.SUBSET_OPTIONS: + raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') + if task not in self.TASKS: + raise ValueError(f'The only tasks that are supported are {self.TASKS}') + + self.task = task + self.subset = subset + self.root = root + self.img_path = os.path.join(self.root, 'JPEGImages', resolution) + annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' + self.mask_path = os.path.join(self.root, annotations_folder, resolution) + year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' + self.imagesets_path = os.path.join(self.root, 'ImageSets', year) + + self._check_directories() + + if sequences == 'all': + with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: + tmp = f.readlines() + sequences_names = [x.strip() for x in tmp] + else: + sequences_names = sequences if isinstance(sequences, list) else [sequences] + self.sequences = defaultdict(dict) + + for seq in sequences_names: + images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() + if len(images) == 0 and not codalab: + raise FileNotFoundError(f'Images for sequence {seq} not found.') + self.sequences[seq]['images'] = images + masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() + masks.extend([-1] * (len(images) - len(masks))) + self.sequences[seq]['masks'] = masks + + def _check_directories(self): + if not os.path.exists(self.root): + raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}') + if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): + raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset ' + f'for the {self.task} task from {self.DATASET_WEB}') + if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): + raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}') + + def get_frames(self, sequence): + for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): + image = np.array(Image.open(img)) + mask = None if msk is None else np.array(Image.open(msk)) + yield image, mask + + def _get_all_elements(self, sequence, obj_type): + obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) + all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) + obj_id = [] + for i, obj in enumerate(self.sequences[sequence][obj_type]): + all_objs[i, ...] 
= np.array(Image.open(obj)) + obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) + return all_objs, obj_id + + def get_all_images(self, sequence): + return self._get_all_elements(sequence, 'images') + + def get_all_masks(self, sequence, separate_objects_masks=False): + masks, masks_id = self._get_all_elements(sequence, 'masks') + masks_void = np.zeros_like(masks) + + # Separate void and object masks + for i in range(masks.shape[0]): + masks_void[i, ...] = masks[i, ...] == 255 + masks[i, masks[i, ...] == 255] = 0 + + if separate_objects_masks: + num_objects = int(np.max(masks[0, ...])) + tmp = np.ones((num_objects, *masks.shape)) + tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] + masks = (tmp == masks[None, ...]) + masks = masks > 0 + return masks, masks_void, masks_id + + def get_sequences(self): + for seq in self.sequences: + yield seq + + +if __name__ == '__main__': + from matplotlib import pyplot as plt + + only_first_frame = True + subsets = ['train', 'val'] + + for s in subsets: + dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s) + for seq in dataset.get_sequences(): + g = dataset.get_frames(seq) + img, mask = next(g) + plt.subplot(2, 1, 1) + plt.title(seq) + plt.imshow(img) + plt.subplot(2, 1, 2) + plt.imshow(mask) + plt.show(block=True) + diff --git a/ext/davis2017/evaluation.py b/ext/davis2017/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..17d30e4428b76d8990f1fbce070db5afcb4583ea --- /dev/null +++ b/ext/davis2017/evaluation.py @@ -0,0 +1,110 @@ +import sys +from tqdm import tqdm +import warnings +warnings.filterwarnings("ignore", category=RuntimeWarning) + +import numpy as np +from ext.davis2017.davis import DAVIS +from ext.davis2017.metrics import db_eval_boundary, db_eval_iou +from ext.davis2017 import utils +from ext.davis2017.results import Results +from scipy.optimize import linear_sum_assignment + + +class DAVISEvaluation(object): + def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False): + """ + Class to evaluate DAVIS sequences from a certain set and for a certain task + :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. + :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised. + :param gt_set: Set to compute the evaluation + :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set. 
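+ :param codalab: if True, the underlying DAVIS reader tolerates missing image folders instead of raising (CodaLab-style server evaluation).
+ Usage sketch (paths are placeholders): DAVISEvaluation(davis_root='data/DAVIS', task='unsupervised', gt_set='val').evaluate('path/to/results') returns a dict of J/F statistics.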
+ """ + self.davis_root = davis_root + self.task = task + self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab) + + @staticmethod + def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): + if all_res_masks.shape[0] > all_gt_masks.shape[0]: + sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!") + sys.exit() + elif all_res_masks.shape[0] < all_gt_masks.shape[0]: + zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) + all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) + j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2]) + for ii in range(all_gt_masks.shape[0]): + if 'J' in metric: + j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) + if 'F' in metric: + f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) + return j_metrics_res, f_metrics_res + + @staticmethod + def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20): + if all_res_masks.shape[0] > max_n_proposals: + sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!") + sys.exit() + elif all_res_masks.shape[0] < all_gt_masks.shape[0]: + zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) + all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) + j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) + f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) + for ii in range(all_gt_masks.shape[0]): + for jj in range(all_res_masks.shape[0]): + if 'J' in metric: + j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) + if 'F' in metric: + f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) + if 'J' in metric and 'F' in metric: + all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2 + else: + all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2) + row_ind, col_ind = linear_sum_assignment(-all_metrics) + return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :] + + def evaluate(self, res_path, metric=('J', 'F'), debug=False): + metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] + if 'T' in metric: + raise ValueError('Temporal metric not supported!') + if 'J' not in metric and 'F' not in metric: + raise ValueError('Metric possible values are J for IoU or F for Boundary') + + # Containers + metrics_res = {} + if 'J' in metric: + metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} + if 'F' in metric: + metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}} + + # Sweep all sequences + results = Results(root_dir=res_path) + for seq in tqdm(list(self.dataset.get_sequences())): + all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True) + if self.task == 'semi-supervised': + all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] + all_res_masks = results.read_masks(seq, all_masks_id) + if self.task == 'unsupervised': + j_metrics_res, f_metrics_res = 
self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric) + elif self.task == 'semi-supervised': + j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric) + for ii in range(all_gt_masks.shape[0]): + seq_name = f'{seq}_{ii+1}' + if 'J' in metric: + [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) + metrics_res['J']["M"].append(JM) + metrics_res['J']["R"].append(JR) + metrics_res['J']["D"].append(JD) + metrics_res['J']["M_per_object"][seq_name] = JM + if 'F' in metric: + [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii]) + metrics_res['F']["M"].append(FM) + metrics_res['F']["R"].append(FR) + metrics_res['F']["D"].append(FD) + metrics_res['F']["M_per_object"][seq_name] = FM + + # Show progress + if debug: + sys.stdout.write(seq + '\n') + sys.stdout.flush() + return metrics_res diff --git a/ext/davis2017/metrics.py b/ext/davis2017/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb2724177cc21de5470233b29bb0360d848a823 --- /dev/null +++ b/ext/davis2017/metrics.py @@ -0,0 +1,197 @@ +import math +import numpy as np +import cv2 + + +def db_eval_iou(annotation, segmentation, void_pixels=None): + """ Compute region similarity as the Jaccard Index. + Arguments: + annotation (ndarray): binary annotation map. + segmentation (ndarray): binary segmentation map. + void_pixels (ndarray): optional mask with void pixels + + Return: + jaccard (float): region similarity + """ + assert annotation.shape == segmentation.shape, \ + f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' + annotation = annotation.astype(bool) + segmentation = segmentation.astype(bool) + + if void_pixels is not None: + assert annotation.shape == void_pixels.shape, \ + f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' + void_pixels = void_pixels.astype(bool) + else: + void_pixels = np.zeros_like(segmentation) + + # Intersection between all sets + inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) + union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) + + j = inters / union + if j.ndim == 0: + j = 1 if np.isclose(union, 0) else j + else: + j[np.isclose(union, 0)] = 1 + return j + + +def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): + assert annotation.shape == segmentation.shape + if void_pixels is not None: + assert annotation.shape == void_pixels.shape + if annotation.ndim == 3: + n_frames = annotation.shape[0] + f_res = np.zeros(n_frames) + for frame_id in range(n_frames): + void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] + f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) + elif annotation.ndim == 2: + f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) + else: + raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') + return f_res + + +def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): + """ + Compute mean,recall and decay from per-frame evaluation. + Calculates precision/recall for boundaries between foreground_mask and + gt_mask using morphological operators to speed it up. + + Arguments: + foreground_mask (ndarray): binary segmentation image. + gt_mask (ndarray): binary annotated image. 
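+ bound_th (float): boundary matching tolerance; values below 1 are treated as a fraction of the image diagonal, values >= 1 as an absolute radius in pixels.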
+ void_pixels (ndarray): optional mask with void pixels + + Returns: + F (float): boundaries F-measure + """ + assert np.atleast_3d(foreground_mask).shape[2] == 1 + if void_pixels is not None: + void_pixels = void_pixels.astype(bool) + else: + void_pixels = np.zeros_like(foreground_mask).astype(bool) + + bound_pix = bound_th if bound_th >= 1 else \ + np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) + + # Get the pixel boundaries of both masks + fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) + gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) + + from skimage.morphology import disk + + # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) + fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) + # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) + gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) + + # Get the intersection + gt_match = gt_boundary * fg_dil + fg_match = fg_boundary * gt_dil + + # Area of the intersection + n_fg = np.sum(fg_boundary) + n_gt = np.sum(gt_boundary) + + # % Compute precision and recall + if n_fg == 0 and n_gt > 0: + precision = 1 + recall = 0 + elif n_fg > 0 and n_gt == 0: + precision = 0 + recall = 1 + elif n_fg == 0 and n_gt == 0: + precision = 1 + recall = 1 + else: + precision = np.sum(fg_match) / float(n_fg) + recall = np.sum(gt_match) / float(n_gt) + + # Compute F measure + if precision + recall == 0: + F = 0 + else: + F = 2 * precision * recall / (precision + recall) + + return F + + +def _seg2bmap(seg, width=None, height=None): + """ + From a segmentation, compute a binary boundary map with 1 pixel wide + boundaries. The boundary pixels are offset by 1/2 pixel towards the + origin from the actual segment boundary. + Arguments: + seg : Segments labeled from 1..k. + width : Width of desired bmap <= seg.shape[1] + height : Height of desired bmap <= seg.shape[0] + Returns: + bmap (ndarray): Binary boundary map. + David Martin + January 2003 + """ + + seg = seg.astype(bool) + seg[seg > 0] = 1 + + assert np.atleast_3d(seg).shape[2] == 1 + + width = seg.shape[1] if width is None else width + height = seg.shape[0] if height is None else height + + h, w = seg.shape[:2] + + ar1 = float(width) / float(height) + ar2 = float(w) / float(h) + + assert not ( + width > w | height > h | abs(ar1 - ar2) > 0.01 + ), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) + + e = np.zeros_like(seg) + s = np.zeros_like(seg) + se = np.zeros_like(seg) + + e[:, :-1] = seg[:, 1:] + s[:-1, :] = seg[1:, :] + se[:-1, :-1] = seg[1:, 1:] + + b = seg ^ e | seg ^ s | seg ^ se + b[-1, :] = seg[-1, :] ^ e[-1, :] + b[:, -1] = seg[:, -1] ^ s[:, -1] + b[-1, -1] = 0 + + if w == width and h == height: + bmap = b + else: + bmap = np.zeros((height, width)) + for x in range(w): + for y in range(h): + if b[y, x]: + j = 1 + math.floor((y - 1) + height / h) + i = 1 + math.floor((x - 1) + width / h) + bmap[j, i] = 1 + + return bmap + + +if __name__ == '__main__': + from davis2017.davis import DAVIS + from davis2017.results import Results + + dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics') + results = Results(root_dir='examples/osvos') + # Test timing F measure + for seq in dataset.get_sequences(): + all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) + all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] + all_res_masks = results.read_masks(seq, all_masks_id) + f_metrics_res = np.zeros(all_gt_masks.shape[:2]) + for ii in range(all_gt_masks.shape[0]): + f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) + + # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py + # snakeviz f_measure.prof diff --git a/ext/davis2017/results.py b/ext/davis2017/results.py new file mode 100644 index 0000000000000000000000000000000000000000..a88d3dae6ff1d2e68fbc1cde369a5dbfaa8c87af --- /dev/null +++ b/ext/davis2017/results.py @@ -0,0 +1,52 @@ +import os +import numpy as np +from PIL import Image, ImagePalette +import sys + + +davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` 
\x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' +mose_palette = b'\x00\x00\x00\xe4\x1a\x1c7~\xb8M\xafJ\x98N\xa3\xff\x7f\x00\xff\xff3\xa6V(\xf7\x81\xbf\x99\x99\x99f\xc2\xa5\xfc\x8db\x8d\xa0\xcb\xe7\x8a\xc3\xa6\xd8T\xff\xd9/\xe5\xc4\x94\xb3\xb3\xb3\x8d\xd3\xc7\xff\xff\xb3\xbe\xba\xda\xfb\x80r\x80\xb1\xd3\xfd\xb4b\xb3\xdei\xfc\xcd\xe5\xd9\xd9\xd9\xbc\x80\xbd\xcc\xeb\xc5\xff\xedo' + +class Results(object): + def __init__(self, root_dir): + self.root_dir = root_dir + + def _read_mask(self, sequence, frame_id): + try: + mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png') + # BUGFIX + # There is a bug in the codebase + # Here is a compensation. + img = Image.open(mask_path) + if img.mode != 'P': + img_color = np.array(img) + h, w, three = img_color.shape + assert three == 3 + + img_new = np.ones((h, w), dtype=np.uint8) * 255 + color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy() + for i in range(10): + cur_color = color_map_np[i] + mask = np.all(img_color == cur_color, axis=-1) + img_new[mask] = i + assert not np.all(img_new == 255).any() + img = img_new + # BUGFIX + return np.array(img) + except IOError as err: + sys.stdout.write(sequence + " frame %s not found!\n" % frame_id) + sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence " + "folder.\nThe indexes have to match with the initial frame.\n") + sys.stderr.write("IOError: " + err.strerror + "\n") + sys.exit() + + def read_masks(self, sequence, masks_id): + mask_0 = self._read_mask(sequence, masks_id[0]) + masks = np.zeros((len(masks_id), *mask_0.shape)) + for ii, m in enumerate(masks_id): + masks[ii, ...] = self._read_mask(sequence, m) + num_objects = int(np.max(masks)) + tmp = np.ones((num_objects, *masks.shape)) + tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] + masks = (tmp == masks[None, ...]) > 0 + return masks diff --git a/ext/davis2017/utils.py b/ext/davis2017/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22e980dd6a6e9e7184073545b22f703b45c059af --- /dev/null +++ b/ext/davis2017/utils.py @@ -0,0 +1,174 @@ +import os +import errno +import numpy as np +from PIL import Image +import warnings +from ext.davis2017.davis import DAVIS + + +def _pascal_color_map(N=256, normalized=False): + """ + Python implementation of the color map function for the PASCAL VOC data set. 
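+ Returns an (N, 3) colour map, uint8 in [0, 255] by default or float32 in [0, 1] when normalized is True.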
+ Official Matlab version can be found in the PASCAL VOC devkit + http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit + """ + + def bitget(byteval, idx): + return (byteval & (1 << idx)) != 0 + + dtype = 'float32' if normalized else 'uint8' + cmap = np.zeros((N, 3), dtype=dtype) + for i in range(N): + r = g = b = 0 + c = i + for j in range(8): + r = r | (bitget(c, 0) << 7 - j) + g = g | (bitget(c, 1) << 7 - j) + b = b | (bitget(c, 2) << 7 - j) + c = c >> 3 + + cmap[i] = np.array([r, g, b]) + + cmap = cmap / 255 if normalized else cmap + return cmap + + +def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): + im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) + if im.shape[:-1] != ann.shape: + raise ValueError('First two dimensions of `im` and `ann` must match') + if im.shape[-1] != 3: + raise ValueError('im must have three channels at the 3 dimension') + + colors = colors or _pascal_color_map() + colors = np.asarray(colors, dtype=np.uint8) + + mask = colors[ann] + fg = im * alpha + (1 - alpha) * mask + + img = im.copy() + img[ann > 0] = fg[ann > 0] + + if contour_thickness: # pragma: no cover + import cv2 + for obj_id in np.unique(ann[ann > 0]): + contours = cv2.findContours((ann == obj_id).astype( + np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] + cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), + contour_thickness) + return img + + +def generate_obj_proposals(davis_root, subset, num_proposals, save_path): + dataset = DAVIS(davis_root, subset=subset, codalab=True) + for seq in dataset.get_sequences(): + save_dir = os.path.join(save_path, seq) + if os.path.exists(save_dir): + continue + all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) + img_size = all_gt_masks.shape[2:] + num_rows = int(np.ceil(np.sqrt(num_proposals))) + proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) + height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() + width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() + ii = 0 + prev_h, prev_w = 0, 0 + for h in height_slices[1:]: + for w in width_slices[1:]: + proposals[ii, :, prev_h:h, prev_w:w] = 1 + prev_w = w + ii += 1 + if ii == num_proposals: + break + prev_h, prev_w = h, 0 + if ii == num_proposals: + break + + os.makedirs(save_dir, exist_ok=True) + for i, mask_id in enumerate(all_masks_id): + mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) + save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) + + +def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): + dataset = DAVIS(davis_root, subset=subset, codalab=True) + for seq in dataset.get_sequences(): + gt_masks, all_masks_id = dataset.get_all_masks(seq, True) + obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) + gt_masks = gt_masks[obj_swap, ...] + save_dir = os.path.join(save_path, seq) + os.makedirs(save_dir, exist_ok=True) + for i, mask_id in enumerate(all_masks_id): + mask = np.sum(gt_masks[:, i, ...] 
* np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) + save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) + + +def color_map(N=256, normalized=False): + def bitget(byteval, idx): + return ((byteval & (1 << idx)) != 0) + + dtype = 'float32' if normalized else 'uint8' + cmap = np.zeros((N, 3), dtype=dtype) + for i in range(N): + r = g = b = 0 + c = i + for j in range(8): + r = r | (bitget(c, 0) << 7-j) + g = g | (bitget(c, 1) << 7-j) + b = b | (bitget(c, 2) << 7-j) + c = c >> 3 + + cmap[i] = np.array([r, g, b]) + + cmap = cmap/255 if normalized else cmap + return cmap + + +def save_mask(mask, img_path): + if np.max(mask) > 255: + raise ValueError('Maximum id pixel value is 255') + mask_img = Image.fromarray(mask.astype(np.uint8)) + mask_img.putpalette(color_map().flatten().tolist()) + mask_img.save(img_path) + + +def db_statistics(per_frame_values): + """ Compute mean,recall and decay from per-frame evaluation. + Arguments: + per_frame_values (ndarray): per-frame evaluation + + Returns: + M,O,D (float,float,float): + return evaluation statistics: mean,recall,decay. + """ + + # strip off nan values + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + M = np.nanmean(per_frame_values) + O = np.nanmean(per_frame_values > 0.5) + + N_bins = 4 + ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 + ids = ids.astype(np.uint8) + + D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) + + return M, O, D + + +def list_files(dir, extension=".png"): + return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] + + +def force_symlink(file1, file2): + try: + os.symlink(file1, file2) + except OSError as e: + if e.errno == errno.EEXIST: + os.remove(file2) + os.symlink(file1, file2) diff --git a/ext/meta/sam_meta.py b/ext/meta/sam_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..aba1fa2daf3d24ba83cc54996ac7b5bf305e02e1 --- /dev/null +++ b/ext/meta/sam_meta.py @@ -0,0 +1,41 @@ +meta_dict = { + 'vit_h': dict( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ), + 'vit_l': dict( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ), + 'vit_b': dict( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + # common + prompt_embed_dim=256, + image_size=1024, + vit_patch_size=16, + image_embedding_size=64 + ) +} + +checkpoint_dict = { + 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', + 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth', + 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth', +} diff --git a/ext/open_clip/__init__.py b/ext/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb1199b8aa87a919abff1bd0020c6624757ac62 --- /dev/null +++ b/ext/open_clip/__init__.py @@ -0,0 +1,15 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from 
.factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy +from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/ext/open_clip/bpe_simple_vocab_16e6.txt.gz b/ext/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/ext/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/ext/open_clip/coca_model.py b/ext/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/ext/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def 
__init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." 
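+ # Decoding strategy is chosen via generation_type: "beam_search" delegates to _generate_beamsearch below, while "top_k" / "top_p" run the sampling loop further down.
+ # The default sot/eos ids used below (49406 / 49407) are the start- and end-of-text tokens of the OpenCLIP BPE tokenizer.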
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + 
num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = 
torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/ext/open_clip/constants.py b/ext/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/ext/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/ext/open_clip/factory.py b/ext/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..dce3e9fbb089804edb9f63775ccb2d832cab2500 --- /dev/null +++ b/ext/open_clip/factory.py @@ -0,0 +1,387 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ + list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif 
config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + logger: logging.Logger = logging, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logger.info(f'Loading pretrained {model_name} from OpenAI.') + 
model = load_openai_model( + model_name, + precision=precision, + device=device, + cache_dir=cache_dir, + ) + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logger.info(f'Loaded {model_name} model config.') + else: + logger.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + from .transformer import LayerNormFp32 + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logger.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logger.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + logger=logger, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + 
force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + logger: logging.Logger = logging, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + logger=logger, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/ext/open_clip/generation_utils.py b/ext/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ext/open_clip/hf_configs.py b/ext/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..13c9bfd8c660eac59f1fbc1912b9fccc9c0c625a --- /dev/null +++ b/ext/open_clip/hf_configs.py @@ -0,0 +1,56 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, +} diff --git a/ext/open_clip/hf_model.py b/ext/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..08dbdbcde02b550ca765ca9bcb0b667be2c0443d --- /dev/null +++ b/ext/open_clip/hf_model.py @@ -0,0 +1,193 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 
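The arch_dict mapping above is how the HF text tower reads width/heads/layers out of heterogeneous HuggingFace configs without hard-coding attribute names. A small illustration, assuming transformers is installed and roberta-base is reachable (neither is guaranteed by this diff), with the import path inferred from the file layout:

from transformers import AutoConfig
from ext.open_clip.hf_configs import arch_dict   # import path inferred from this diff's layout

cfg = AutoConfig.from_pretrained('roberta-base')
width_attr = arch_dict['roberta']['config_names']['width']   # -> 'hidden_size'
print(getattr(cfg, width_attr))                              # 768 for roberta-base
print(arch_dict['roberta']['pooler'])                        # 'mean_pooler'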
+""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + return {"contrastive_loss": total_loss} if output_dict else total_loss + + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + + clip_loss = torch.tensor(0) + + if self.clip_loss_weight: + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * 
student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/ext/open_clip/model.py b/ext/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f85b68ba23117cb65d082cf5cd4cf7528bab4619 --- /dev/null +++ b/ext/open_clip/model.py @@ -0,0 +1,473 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + output_tokens: bool = False + + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. 
# head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = 
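The precision helpers above are what create_model consults when deciding between manual fp16/bf16 conversion, 'pure_*' casting, and leaving dtypes alone for amp/fp32. Their mapping can be pinned down with a few assertions (self-contained, torch only):

import torch

assert get_cast_dtype('fp16') is torch.float16
assert get_cast_dtype('bf16') is torch.bfloat16
assert get_cast_dtype('amp') is None                  # autocast manages casting on its own
assert get_input_dtype('pure_bf16') is torch.bfloat16
assert get_input_dtype('fp32') is None                # inputs stay in float32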
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = 
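The pooling step in CLIP.encode_text above depends on the tokenizer convention that the end-of-text id is the numerically largest id in each row, so argmax finds its position even through padding. A toy illustration; the ids shown are the standard CLIP BPE special tokens, assumed here rather than taken from this diff:

import torch

# one padded sequence: <start> "a" "photo" <end> <pad> <pad>
text = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])
eot_pos = text.argmax(dim=-1)     # tensor([3]): position of the end-of-text token
# encode_text then gathers x[torch.arange(x.shape[0]), eot_pos] and applies text_projection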
_build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = 
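build_model_from_openai_state_dict above recovers the entire architecture from tensor shapes alone. Worked through for a standard OpenAI ViT-B/16 checkpoint (shapes quoted from the public release, not from this diff):

# visual.conv1.weight        : (768, 3, 16, 16) -> vision_width = 768, patch_size = 16
# visual.positional_embedding: (197, 768)       -> grid_size = round((197 - 1) ** 0.5) = 14,
#                                                  image_size = 16 * 14 = 224
# ln_final.weight            : (512,)           -> transformer_width = 512, heads = 512 // 64 = 8
# text_projection            : (512, 512)       -> embed_dim = 512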
to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/ext/open_clip/model_configs/EVA01-g-14-plus.json b/ext/open_clip/model_configs/EVA01-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..73f46a71e664fce987218b8eb48903e7bd895f41 --- /dev/null +++ b/ext/open_clip/model_configs/EVA01-g-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA01-g-14.json b/ext/open_clip/model_configs/EVA01-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..9d0e80f290d9491b7c46fafd576201b1258165aa --- /dev/null +++ b/ext/open_clip/model_configs/EVA01-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva_giant_patch14_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-B-16.json b/ext/open_clip/model_configs/EVA02-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..3f92357287e1f6600da1e7f391cb6370d7f66de4 --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-B-16.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_base_patch16_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-E-14-plus.json b/ext/open_clip/model_configs/EVA02-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..e250c2a404c86ff168c54cfcf71bc2492be1b74c --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-E-14-plus.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-E-14.json b/ext/open_clip/model_configs/EVA02-E-14.json new file mode 100644 index 0000000000000000000000000000000000000000..4b6648e25092b151a9095e0a66956c7ebf835b16 --- /dev/null +++ 
b/ext/open_clip/model_configs/EVA02-E-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_enormous_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-L-14-336.json b/ext/open_clip/model_configs/EVA02-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..2bb07f3c082fd88c4e86131b272163aaacfaef9e --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-L-14-336.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "timm_model_name": "eva02_large_patch14_clip_336", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/EVA02-L-14.json b/ext/open_clip/model_configs/EVA02-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c7f377bc543aa92a145358f2630a58ae9be989 --- /dev/null +++ b/ext/open_clip/model_configs/EVA02-L-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "timm_model_name": "eva02_large_patch14_clip_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN101-quickgelu.json b/ext/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/ext/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN101.json b/ext/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/ext/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50-quickgelu.json b/ext/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/ext/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git 
a/ext/open_clip/model_configs/RN50.json b/ext/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/ext/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x16.json b/ext/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/ext/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x4.json b/ext/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/ext/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/RN50x64.json b/ext/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/ext/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16-plus-240.json b/ext/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16-plus.json b/ext/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-16.json b/ext/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 
0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32-plus-256.json b/ext/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32-quickgelu.json b/ext/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-B-32.json b/ext/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/ext/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-H-14.json b/ext/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-H-16.json b/ext/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/ext/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14-280.json b/ext/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ 
b/ext/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14-336.json b/ext/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-14.json b/ext/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-16-320.json b/ext/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-L-16.json b/ext/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-16-alt.json b/ext/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-16.json b/ext/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + 
"text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-32-alt.json b/ext/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-M-32.json b/ext/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-16-alt.json b/ext/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-16.json b/ext/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-32-alt.json b/ext/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-S-32.json b/ext/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-bigG-14.json 
b/ext/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-e-14.json b/ext/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/ext/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/ViT-g-14.json b/ext/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/ext/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_ViT-B-32.json b/ext/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/ext/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_ViT-L-14.json b/ext/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/ext/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/ext/open_clip/model_configs/coca_base.json b/ext/open_clip/model_configs/coca_base.json new file mode 100644 
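A note on the ViT configs in this batch: the number of attention heads is never written down; _build_vision_tower derives it as width // head_width, with head_width defaulting to 64 when a config omits it. For example:

# ViT-B-32   : width 768,  default head_width 64 -> 768  // 64  = 12 heads
# ViT-H-14   : width 1280, head_width 80         -> 1280 // 80  = 16 heads
# ViT-bigG-14: width 1664, head_width 104        -> 1664 // 104 = 16 heads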
index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/ext/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json b/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/ext/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/ext/open_clip/model_configs/convnext_base.json b/ext/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_base_w.json b/ext/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_base_w_320.json b/ext/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git 
a/ext/open_clip/model_configs/convnext_large.json b/ext/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_large_d.json b/ext/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_large_d_320.json b/ext/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_small.json b/ext/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_tiny.json b/ext/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/ext/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xlarge.json b/ext/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 
0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xxlarge.json b/ext/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/convnext_xxlarge_320.json b/ext/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/ext/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/mt5-base-ViT-B-32.json b/ext/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/ext/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json b/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/ext/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/roberta-ViT-B-32.json b/ext/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/ext/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + 
"patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/swin_base_patch4_window7_224.json b/ext/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/ext/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json b/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/ext/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/ext/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/ext/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/ext/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/ext/open_clip/modified_resnet.py 
b/ext/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9 --- /dev/null +++ b/ext/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/ext/open_clip/openai.py b/ext/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c0235245c2e4f1217b3b2bfaf2acf78e74981 --- /dev/null +++ b/ext/open_clip/openai.py @@ -0,0 +1,90 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
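# Shape-check sketch for the ModifiedResNet defined above (assumptions: RN50-style layer
# counts, width 64, 32 attention-pool heads, 1024-dim output; random weights, illustrative only).
import torch

from ext.open_clip.modified_resnet import ModifiedResNet

visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
with torch.no_grad():
    feats = visual(torch.randn(2, 3, 224, 224))
print(feats.shape)  # torch.Size([2, 1024]) -- AttentionPool2d returns one embedding per image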
+""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + state_dict = torch.load(model_path, map_location="cpu") + + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + # FIXME support pure fp16/bf16 precision modes + if precision != 'fp16': + model.float() + if precision == 'bf16': + # for bf16, convert back to low-precision + convert_weights_to_lp(model, dtype=torch.bfloat16) + + # add mean / std attributes for consistency with OpenCLIP models + model.visual.image_mean = OPENAI_DATASET_MEAN + model.visual.image_std = OPENAI_DATASET_STD + return model diff --git a/ext/open_clip/pretrained.py b/ext/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..1465a2325652be7e7a1d7563698e38b9ec408cc6 --- /dev/null +++ b/ext/open_clip/pretrained.py @@ -0,0 +1,427 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + 
_has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + 
datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + 
commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + 
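# Query sketch for the pretrained registry above (assumptions: vendored import path
# `ext.open_clip`; the helper functions used here are defined further down in this file).
from ext.open_clip.pretrained import get_pretrained_cfg, list_pretrained_tags_by_model

print(list_pretrained_tags_by_model('convnext_large_d_320'))
# ['laion2b_s29b_b131k_ft', 'laion2b_s29b_b131k_ft_soup']
cfg = get_pretrained_cfg('convnext_large_d_320', 'laion2b_s29b_b131k_ft_soup')
print(cfg['hf_hub'])  # laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/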
"convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ) +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, 
"rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/ext/open_clip/push_to_hf_hub.py b/ext/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6271da1d35e36ea22e92d339dc9465d0793249 --- /dev/null +++ b/ext/open_clip/push_to_hf_hub.py @@ -0,0 +1,280 @@ +import argparse +import json +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + list_repo_files, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + +# Default name for a weights file hosted on the Huggingface Hub. 
+HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + safe_serialization: Union[bool, str] = False, + skip_weights : bool = False, +): + config_filename = HF_CONFIG_NAME + + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + if not skip_weights: + tensors = model.state_dict() + if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) + if safe_serialization is False or safe_serialization == "both": + torch.save(tensors, save_directory / HF_WEIGHTS_NAME) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, str] = False, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if repo already exists and determine what needs updating + repo_exists = False + repo_files = {} + try: + repo_files = set(list_repo_files(repo_id)) + repo_exists = True + except Exception as e: + print('Repo does not exist', e) + + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. 
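# Usage sketch for save_for_hf above (assumptions: run from outside this module with the
# vendored package importable as `ext.open_clip`; './hf_export' is a hypothetical directory).
# safe_serialization=False writes open_clip_pytorch_model.bin, True writes
# open_clip_model.safetensors, and 'both' writes the two formats side by side, always
# together with open_clip_config.json and the tokenizer files.
from ext.open_clip import create_model_and_transforms, get_model_config
from ext.open_clip.tokenizer import HFTokenizer
from ext.open_clip.push_to_hf_hub import save_for_hf

model, _, _ = create_model_and_transforms('convnext_large_d_320', pretrained=None)
save_for_hf(
    model,
    tokenizer=HFTokenizer('openai/clip-vit-large-patch14'),  # default CLIP tokenizer, as in push_to_hf_hub
    model_config=get_model_config('convnext_large_d_320'),
    save_directory='./hf_export',
    safe_serialization='both',
)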
+ save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + safe_serialization=safe_serialization, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + precision: str = 'fp32', + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + precision=precision, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + safe_serialization='both', + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- clip\n" + readme_text += "library_name: open_clip\n" + readme_text += "pipeline_tag: zero-shot-image-classification\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub 
repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + "--precision", type=str, default='fp32', + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + precision=args.precision, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/ext/open_clip/timm_model.py b/ext/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3d3f595d67cdedd142b6312d26924e8e58c67086 --- /dev/null +++ b/ext/open_clip/timm_model.py @@ -0,0 +1,149 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + patch_drop=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + self.image_size = to_2tuple(image_size) + + # setup kwargs that may not be common across all models + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + if patch_drop is not None: + timm_kwargs['patch_drop_rate'] = patch_drop + + custom_pool = pool in ('abs_attn', 'rot_attn') + if not proj and not custom_pool: + # use network classifier head as projection if no proj specified and no custom pooling used + self.trunk = timm.create_model( + model_name, + num_classes=embed_dim, + global_pool=pool, + pretrained=pretrained, + **timm_kwargs, + ) + prev_chs = embed_dim + else: + self.trunk = timm.create_model( + model_name, + pretrained=pretrained, + **timm_kwargs, + ) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if custom_pool: + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + + # Add custom pooling to head + if pool == 'abs_attn': + 
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + else: + assert not proj, f'Unknown projection type {proj}.' + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/ext/open_clip/tokenizer.py b/ext/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/ext/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. 
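# Shape-check sketch for the TimmModel adapter above (assumptions: timm is installed, a
# 'convnext_large' trunk with an MLP projection to 768 dims, mirroring the convnext_large_d
# configs earlier in this diff; weights are random here).
import torch

from ext.open_clip.timm_model import TimmModel

tower = TimmModel('convnext_large', embed_dim=768, image_size=320, pool='', proj='mlp', drop_path=0.1)
with torch.no_grad():
    emb = tower(torch.randn(1, 3, 320, 320))
print(emb.shape)  # torch.Size([1, 768]) -- 1536-dim trunk features projected by the MLP head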
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['<start_of_text>', '<end_of_text>'] + else: + special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = 
output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<start_of_text>"] + eot_token = _tokenizer.encoder["<end_of_text>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/ext/open_clip/transform.py b/ext/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..748884a3c7cb7ece1ca521ca1dbf40bb74855007 --- /dev/null +++ b/ext/open_clip/transform.py @@ -0,0 +1,133 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
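# Usage sketch for the tokenizer above (assumption: vendored import path `ext.open_clip`).
from ext.open_clip.tokenizer import tokenize, _tokenizer

ids = tokenize(["a photo of a cat", "a diagram"])  # LongTensor of shape [2, 77], zero-padded
print(ids.shape, ids[0, 0].item() == _tokenizer.encoder["<start_of_text>"])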
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = Normalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_longest_max: + transforms = [ + ResizeMaxSize(image_size, fill=fill_color) + ] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) diff --git a/ext/open_clip/transformer.py b/ext/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a30e94664a2dd890a373eb0a0f640818836baaa --- /dev/null +++ b/ext/open_clip/transformer.py @@ -0,0 +1,726 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class 
LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. 
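# Shape sketch for PatchDropout above (assumptions: ViT-style tokens with a leading CLS
# token, 14x14 = 196 patches, width 768; the dropout is only applied in training mode).
import torch

from ext.open_clip.transformer import PatchDropout

tokens = torch.randn(2, 1 + 196, 768)
drop = PatchDropout(prob=0.5).train()
print(drop(tokens).shape)  # torch.Size([2, 99, 768]) -- CLS kept, 98 of 196 patches sampled
drop.eval()
print(drop(tokens).shape)  # torch.Size([2, 197, 768]) -- identity at eval time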
+ ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = 
norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, 
None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. 
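+        # NOTE: as written this method is effectively a no-op (it ends in `pass` below);
+        # class_embedding, positional_embedding and proj are already scaled by width ** -0.5
+        # at construction, and all other modules keep PyTorch's default initializers.
+        # The commented-out block below is the alternate OpenAI-CLIP-style init the TODO refers to.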
+ + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, 
t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, 
cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/ext/open_clip/utils.py b/ext/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb --- /dev/null +++ b/ext/open_clip/utils.py @@ -0,0 +1,89 @@ +from itertools import repeat +import collections.abc + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in 
include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/ext/open_clip/version.py b/ext/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a910817da22d06aa0244c6d488b40d30da2bfb7e --- /dev/null +++ b/ext/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.20.0' diff --git a/ext/open_clip/zero_shot_classifier.py b/ext/open_clip/zero_shot_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..12b58f65bb0875b164946a9ee73e938aef255382 --- /dev/null +++ b/ext/open_clip/zero_shot_classifier.py @@ -0,0 +1,110 @@ +from functools import partial +from itertools import islice +from typing import Callable, List, Optional, Sequence, Union + +import torch +import torch.nn.functional as F + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl + """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch + + +def build_zero_shot_classifier( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + num_classes_per_batch: Optional[int] = 10, + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names in batches + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + num_classes_per_batch: The number of classes to batch together in each forward, all if None + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
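+    Returns:
+        A tensor of shape (embed_dim, num_classes): each column is the class embedding averaged
+        over all templates and L2-normalized, e.g. for `logits = image_features @ zeroshot_weights`.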
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + use_format = isinstance(templates[0], str) + num_templates = len(templates) + num_classes = len(classnames) + if use_tqdm: + import tqdm + num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) + iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) + else: + iter_wrap = iter + + def _process_batch(batch_classnames): + num_batch_classes = len(batch_classnames) + texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] + texts = tokenizer(texts).to(device) + class_embeddings = F.normalize(model.encode_text(texts), dim=-1) + class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) + class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) + class_embeddings = class_embeddings.T + return class_embeddings + + with torch.no_grad(): + if num_classes_per_batch: + batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] + zeroshot_weights = torch.cat(batched_embeds, dim=1) + else: + zeroshot_weights = _process_batch(classnames) + return zeroshot_weights + + +def build_zero_shot_classifier_legacy( + model, + tokenizer, + classnames: Sequence[str], + templates: Sequence[Union[Callable, str]], + device: Union[str, torch.device] = 'cpu', + use_tqdm: bool = False, +): + """ Build zero-shot classifier weights by iterating over class names 1 by 1 + Args: + model: CLIP model instance + tokenizer: CLIP tokenizer instance + classnames: A sequence of class (label) names + templates: A sequence of callables or format() friendly strings to produce templates per class name + device: Device to use. + use_tqdm: Enable TQDM progress bar. 
+ """ + assert isinstance(templates, Sequence) and len(templates) > 0 + assert isinstance(classnames, Sequence) and len(classnames) > 0 + if use_tqdm: + import tqdm + iter_wrap = tqdm.tqdm + else: + iter_wrap = iter + + use_format = isinstance(templates[0], str) + + with torch.no_grad(): + zeroshot_weights = [] + for classname in iter_wrap(classnames): + texts = [template.format(classname) if use_format else template(classname) for template in templates] + texts = tokenizer(texts).to(device) # tokenize + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) + + return zeroshot_weights + diff --git a/ext/open_clip/zero_shot_metadata.py b/ext/open_clip/zero_shot_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb452bbb6e27b71cff1dd27e2bb263259b9363f --- /dev/null +++ b/ext/open_clip/zero_shot_metadata.py @@ -0,0 +1,266 @@ + +OPENAI_IMAGENET_TEMPLATES = ( + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a 
{c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +) + + +# a much smaller subset of above prompts +# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb +SIMPLE_IMAGENET_TEMPLATES = ( + lambda c: f'itap of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'art of the {c}.', + lambda c: f'a photo of the small {c}.', +) + + +IMAGENET_CLASSNAMES = ( + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", 
"Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", 
"mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid 
camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's 
slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" +) + diff --git a/ext/sam/__init__.py b/ext/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9636764f4a9e6500fac352c388063d9b629aab --- /dev/null +++ b/ext/sam/__init__.py @@ -0,0 +1,3 @@ +from .image_encoder import ImageEncoderViT +from .prompt_encoder import PromptEncoder +from .mask_decoder import MaskDecoder diff --git a/ext/sam/common.py b/ext/sam/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/ext/sam/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/ext/sam/image_encoder.py b/ext/sam/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a --- /dev/null +++ b/ext/sam/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. 
+ patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
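+
+        Note: when window_size > 0, forward() pads the (H, W) token grid to a multiple of
+        window_size, runs self-attention within each window, and then strips the padding again
+        (see window_partition / window_unpartition below); window_size == 0 attends globally.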
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
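+
+    Example (illustrative shapes only; a 7x7 query/key grid with head_dim=32):
+        >>> attn = torch.zeros(2, 49, 49)
+        >>> q = torch.zeros(2, 49, 32)
+        >>> rel_h = torch.zeros(2 * 7 - 1, 32)
+        >>> rel_w = torch.zeros(2 * 7 - 1, 32)
+        >>> add_decomposed_rel_pos(attn, q, rel_h, rel_w, (7, 7), (7, 7)).shape
+        torch.Size([2, 49, 49])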
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/ext/sam/mask_decoder.py b/ext/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..967fbdfd3024aed01aa604a0420cb3240720fcbc --- /dev/null +++ b/ext/sam/mask_decoder.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + +from .transformer import TwoWayTransformer + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + with_iou: bool = True + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = TwoWayTransformer( + depth=2, + embedding_dim=transformer_dim, + mlp_dim=2048, + num_heads=8, + ) + + self.num_multimask_outputs = num_multimask_outputs + + if with_iou: + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + if with_iou: + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/ext/sam/prompt_encoder.py b/ext/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/ext/sam/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
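+
+        Example (illustrative; SAM-style sizes assumed, not taken from this
+        repo's configs):
+            >>> pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
+            ...                    input_image_size=(1024, 1024), mask_in_chans=16)
+            >>> pts = (torch.rand(1, 2, 2) * 1024, torch.ones(1, 2))
+            >>> sparse, dense = pe(points=pts, boxes=None, masks=None)
+            >>> sparse.shape, dense.shape   # a padding point is appended when no box is given
+            (torch.Size([1, 3, 256]), torch.Size([1, 256, 64, 64]))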
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. 
+ + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/ext/sam/transformer.py b/ext/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/ext/sam/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
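+
+# Overview: TwoWayTransformer below alternates token->image and image->token
+# cross-attention ("two-way" attention) between prompt tokens and the image
+# embedding. MaskDecoder (mask_decoder.py in this diff) instantiates it with
+# depth=2, embedding_dim=transformer_dim, mlp_dim=2048, num_heads=8.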
+ +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
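+        # internal_dim = embedding_dim // downsample_rate: the cross-attention
+        # layers above pass downsample_rate=attention_downsample_rate (2 by
+        # default), so q/k/v are projected to half the channel width and
+        # out_proj maps back to embedding_dim, roughly halving attention cost.
+        # E.g. embedding_dim=256, num_heads=8, downsample_rate=2 gives a
+        # per-head width of 256 // 2 // 8 = 16 (illustrative values only).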
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/ext/templates/__init__.py b/ext/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a74442361a8e1d9b18aac0dda26ed7fc750ef562 --- /dev/null +++ b/ext/templates/__init__.py @@ -0,0 +1 @@ +from .vild import VILD_PROMPT diff --git a/ext/templates/vild.py b/ext/templates/vild.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3eae4ee3a02cd5d57e5b572f96bbb338b75141 --- /dev/null +++ b/ext/templates/vild.py @@ -0,0 +1,17 @@ +# https://github.com/bytedance/fc-clip/blob/93f3122518e8a3ef98926e5ea761a776d5050430/fcclip/fcclip.py#L26C1-L41C2 +VILD_PROMPT = [ + "a photo of a {}.", + "This is a photo of a {}", + "There is a {} in the scene", + "There is the {} in the scene", + "a photo of a {} in the scene", + "a photo of a small {}.", + "a photo of a medium {}.", + "a photo of a large {}.", + "This is a photo of a small {}.", + "This is a photo of a medium {}.", + "This is a photo of a large {}.", + "There is a small {} in the scene.", + "There is a medium {} in the scene.", + "There is a large {} in the scene.", +] diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a084857e67f387231cf327c3a9dd529ae5796f80 --- /dev/null +++ b/main.py @@ -0,0 +1,240 @@ +import gradio as gr + +import numpy as np + +import torch +import torch.nn.functional as F +from PIL import Image + +# mm libs +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.visualization import DetLocalVisualizer +from mmengine import Config, print_log +from mmengine.structures import InstanceData + +from mmdet.datasets.coco_panoptic import CocoPanopticDataset + +from PIL import ImageDraw + +IMG_SIZE = 1024 + +TITLE = "
OMG-Seg: Is One Model Good Enough For All Segmentation?
" +CSS = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }" + +model_cfg = Config.fromfile('app/configs/m2_convl.py') + +model = MODELS.build(model_cfg.model) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = model.to(device=device) +model = model.eval() +model.init_weights() + +mean = torch.tensor([123.675, 116.28, 103.53], device=device)[:, None, None] +std = torch.tensor([58.395, 57.12, 57.375], device=device)[:, None, None] + +visualizer = DetLocalVisualizer() + +examples = [ + ["assets/000000000139.jpg"], + ["assets/000000000285.jpg"], + ["assets/000000000632.jpg"], + ["assets/000000000724.jpg"], +] + + +class IMGState: + def __init__(self): + self.img = None + self.selected_points = [] + self.available_to_set = True + + def set_img(self, img): + self.img = img + self.available_to_set = False + + def clear(self): + self.img = None + self.selected_points = [] + self.available_to_set = True + + def clean(self): + self.selected_points = [] + + @property + def available(self): + return self.available_to_set + + @classmethod + def cls_clean(cls, state): + state.clean() + return Image.fromarray(state.img), None + + @classmethod + def cls_clear(cls, state): + state.clear() + return None, None + + +def store_img(img, img_state): + w, h = img.size + scale = IMG_SIZE / max(w, h) + new_w = int(w * scale) + new_h = int(h * scale) + img = img.resize((new_w, new_h), resample=Image.Resampling.BILINEAR) + img_numpy = np.array(img) + img_state.set_img(img_numpy) + print_log(f"Successfully loaded an image with size {new_w} x {new_h}", logger='current') + + return img, None + + +def get_points_with_draw(image, img_state, evt: gr.SelectData): + x, y = evt.index[0], evt.index[1] + print_log(f"Point: {x}_{y}", logger='current') + point_radius, point_color = 10, (97, 217, 54) + + img_state.selected_points.append([x, y]) + if len(img_state.selected_points) > 0: + img_state.selected_points = img_state.selected_points[-1:] + image = Image.fromarray(img_state.img) + + draw = ImageDraw.Draw(image) + draw.ellipse( + [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], + fill=point_color, + ) + return image + + +def segment_point(image, img_state): + output_img = img_state.img + h, w = output_img.shape[:2] + + img_tensor = torch.tensor(output_img, device=device, dtype=torch.float32).permute((2, 0, 1))[None] + img_tensor = (img_tensor - mean) / std + + im_w = w if w % 32 == 0 else w // 32 * 32 + 32 + im_h = h if h % 32 == 0 else h // 32 * 32 + 32 + img_tensor = F.pad(img_tensor, (0, im_w - w, 0, im_h - h), 'constant', 0) + + if len(img_state.selected_points) > 0: + input_points = torch.tensor(img_state.selected_points, dtype=torch.float32, device=device) + batch_data_samples = [DetDataSample()] + selected_point = torch.cat([input_points - 3, input_points + 3], 1) + gt_instances = InstanceData( + point_coords=selected_point, + ) + pb_labels = torch.ones(len(gt_instances), dtype=torch.long, device=device) + gt_instances.pb_labels = pb_labels + batch_data_samples[0].gt_instances_collected = gt_instances + batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w))) + batch_data_samples[0].set_metainfo(dict(img_shape=(h, w))) + is_prompt = True + else: + batch_data_samples = [DetDataSample()] + batch_data_samples[0].set_metainfo(dict(batch_input_shape=(im_h, im_w))) + batch_data_samples[0].set_metainfo(dict(img_shape=(h, w))) + is_prompt = False + with torch.no_grad(): + masks, cls_pred = 
model.predict_with_point(img_tensor, batch_data_samples) + + assert len(masks) == 1 + masks = masks[0] + if is_prompt: + masks = masks[0, :h, :w] + masks = masks > 0. # no sigmoid + rgb_shape = tuple(list(masks.shape) + [3]) + color = np.zeros(rgb_shape, dtype=np.uint8) + color[masks] = np.array([97, 217, 54]) + output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8) + output_img = Image.fromarray(output_img) + else: + output_img = visualizer._draw_panoptic_seg( + output_img, + masks['pan_results'].to('cpu').numpy(), + classes=CocoPanopticDataset.METAINFO['classes'], + palette=CocoPanopticDataset.METAINFO['palette'] + ) + return image, output_img + + +def register_title(): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown(TITLE) + + +def register_point_mode(): + with gr.Tab("Point mode"): + img_state = gr.State(IMGState()) + with gr.Row(variant="panel"): + with gr.Column(scale=1): + img_p = gr.Image(label="Input Image", type="pil") + + with gr.Column(scale=1): + segm_p = gr.Image(label="Segment", interactive=False, type="pil") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + segment_btn = gr.Button("Segment", variant="primary") + with gr.Column(): + clean_btn = gr.Button("Clean Prompts", variant="secondary") + + with gr.Row(): + with gr.Column(): + gr.Markdown("Try some of the examples below ⬇️") + gr.Examples( + examples=examples, + inputs=[img_p, img_state], + outputs=[img_p, segm_p], + examples_per_page=4, + fn=store_img, + run_on_click=True + ) + + img_p.upload( + store_img, + [img_p, img_state], + [img_p, segm_p] + ) + + img_p.select( + get_points_with_draw, + [img_p, img_state], + img_p + ) + + segment_btn.click( + segment_point, + [img_p, img_state], + [img_p, segm_p] + ) + + clean_btn.click( + IMGState.cls_clean, + img_state, + [img_p, segm_p] + ) + + img_p.clear( + IMGState.cls_clear, + img_state, + [img_p, segm_p] + ) + + +def build_demo(): + with gr.Blocks(css=CSS, title="RAP-SAM") as _demo: + register_title() + register_point_mode() + return _demo + + +if __name__ == '__main__': + demo = build_demo() + + demo.queue(api_open=False) + demo.launch(server_name='0.0.0.0') diff --git a/models/convnext_large_d_320_CocoPanopticOVDataset.pth b/models/convnext_large_d_320_CocoPanopticOVDataset.pth new file mode 100644 index 0000000000000000000000000000000000000000..ae4076f97d2b541acf5af57361b64d58f95929b8 --- /dev/null +++ b/models/convnext_large_d_320_CocoPanopticOVDataset.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:419c94281ff10061be69a89007564be71c5862bec4751ee209d89c0c758fe029 +size 6947275 diff --git a/models/omg_seg_convl.pth b/models/omg_seg_convl.pth new file mode 100644 index 0000000000000000000000000000000000000000..876525bb0f59f52bc4118002b4e2439233e3fb9c --- /dev/null +++ b/models/omg_seg_convl.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c0cdea9e7ff19566c21645bfe522837d4fe2d9baead13041511ddb09f9968ed +size 84182884 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d83463f2cccded00a9c26cd34c423b1654b9e68d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +-f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1.0/index.html + +torch==2.1.2 +torchvision +mmengine==0.10.2 +mmcv==2.1.0 +mmdet==3.3.0 +ftfy +timm +regex diff --git a/seg/configs/_base_/datasets/ade_panoptic.py b/seg/configs/_base_/datasets/ade_panoptic.py new file mode 100644 index 
0000000000000000000000000000000000000000..b9ac69f59f6546686303ba1d3698e2df6295d2d5 --- /dev/null +++ b/seg/configs/_base_/datasets/ade_panoptic.py @@ -0,0 +1,95 @@ +# dataset settings +from mmcv.transforms import LoadImageFromFile, RandomResize +from mmengine.dataset import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import LoadPanopticAnnotations, RandomFlip, RandomCrop, PackDetInputs, Resize +from mmdet.evaluation import CocoPanopticMetric + +from mmdet.datasets.ade20k import ADE20KPanopticDataset + + +data_root = 'data/ade/' +backend_args = None +image_size = (1024, 1024) + +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadPanopticAnnotations, + with_bbox=True, + with_mask=True, + with_seg=True, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=ADE20KPanopticDataset, + data_root=data_root, + ann_file='ADEChallengeData2016/ade20k_panoptic_train.json', + data_prefix=dict(img='ADEChallengeData2016/images/training/', + seg='ADEChallengeData2016/ade20k_panoptic_train/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(2560, 640), keep_ratio=True), + dict(type=LoadPanopticAnnotations, backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') + ) +] +val_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=ADE20KPanopticDataset, + data_root=data_root, + ann_file='ADEChallengeData2016/ade20k_panoptic_val.json', + data_prefix=dict(img='ADEChallengeData2016/images/validation/', + seg='ADEChallengeData2016/ade20k_panoptic_val/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoPanopticMetric, + ann_file=data_root + 'ADEChallengeData2016/ade20k_panoptic_val.json', + seg_prefix=data_root + 'ADEChallengeData2016/ade20k_panoptic_val/', + backend_args=backend_args +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/ade_panoptic_ov.py b/seg/configs/_base_/datasets/ade_panoptic_ov.py new file mode 100644 index 0000000000000000000000000000000000000000..aa168fc91c33a418454181d15b382debe2d8380c --- /dev/null +++ b/seg/configs/_base_/datasets/ade_panoptic_ov.py @@ -0,0 +1,94 @@ +# dataset settings +from mmcv.transforms import LoadImageFromFile, RandomResize +from mmengine.dataset import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import LoadPanopticAnnotations, RandomFlip, RandomCrop, PackDetInputs, Resize +from mmdet.evaluation import CocoPanopticMetric + +from seg.datasets.ade_ov import ADEPanopticOVDataset + +data_root = 'data/ade/' 
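+# Note: these dataset configs use the pure-Python mmengine style, referencing
+# transform/dataset classes directly rather than registry strings; they are
+# meant to be loaded with Config.fromfile(...) (as main.py does for the app
+# config), not executed as standalone scripts. ADEPanopticOVDataset appears to
+# be the open-vocabulary variant of the ADE20K panoptic dataset.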
+backend_args = None +image_size = (1024, 1024) + +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadPanopticAnnotations, + with_bbox=True, + with_mask=True, + with_seg=True, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=ADEPanopticOVDataset, + data_root=data_root, + ann_file='ADEChallengeData2016/ade20k_panoptic_train.json', + data_prefix=dict(img='ADEChallengeData2016/images/training/', + seg='ADEChallengeData2016/ade20k_panoptic_train/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(2560, 640), keep_ratio=True), + dict(type=LoadPanopticAnnotations, backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') + ) +] +val_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=ADEPanopticOVDataset, + data_root=data_root, + ann_file='ADEChallengeData2016/ade20k_panoptic_val.json', + data_prefix=dict(img='ADEChallengeData2016/images/validation/', + seg='ADEChallengeData2016/ade20k_panoptic_val/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoPanopticMetric, + # ann_file=data_root + 'ADEChallengeData2016/ade20k_panoptic_val.json', + seg_prefix=data_root + 'ADEChallengeData2016/ade20k_panoptic_val/', + backend_args=backend_args +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/cityscapes_panoptic.py b/seg/configs/_base_/datasets/cityscapes_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..b83b99b478577255e6188e1e9013d5e822af3324 --- /dev/null +++ b/seg/configs/_base_/datasets/cityscapes_panoptic.py @@ -0,0 +1,95 @@ +# dataset settings +from mmcv.transforms import LoadImageFromFile, RandomResize +from mmengine.dataset import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import RandomFlip, RandomCrop, PackDetInputs, Resize + +from seg.datasets.pipeliens.loading import LoadPanopticAnnotationsHB +from seg.datasets.cityscapes import CityscapesPanopticDataset +from seg.evaluation.metrics.cityscapes_panoptic_metric import CityscapesPanopticMetric + +data_root = 'data/cityscapes/' +backend_args = None +image_size = (1024, 1024) + +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadPanopticAnnotationsHB, + with_bbox=True, + with_mask=True, + with_seg=True, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.8, 1.5), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + 
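+        # Note: Cityscapes keeps the same 1024x1024 'absolute' crop as the other
+        # datasets, but the RandomResize above uses a milder jitter (0.8-1.5x vs
+        # the 0.1-2.0x used for COCO/ADE).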
crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CityscapesPanopticDataset, + data_root=data_root, + ann_file='annotations/cityscapes_panoptic_train_trainId.json', + data_prefix=dict(img='leftImg8bit/train/', + seg='annotations/cityscapes_panoptic_train_trainId/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(2048, 1024), keep_ratio=True), + dict(type=LoadPanopticAnnotationsHB, backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'instances') + ) +] +val_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CityscapesPanopticDataset, + data_root=data_root, + ann_file='annotations/cityscapes_panoptic_val_trainId.json', + data_prefix=dict(img='leftImg8bit/val/', + seg='annotations/cityscapes_panoptic_val_trainId/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CityscapesPanopticMetric, + ann_file=data_root + 'annotations/cityscapes_panoptic_val_trainId.json', + seg_prefix=data_root + 'annotations/cityscapes_panoptic_val_trainId/', + backend_args=backend_args +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/coco_panoptic_lsj.py b/seg/configs/_base_/datasets/coco_panoptic_lsj.py new file mode 100644 index 0000000000000000000000000000000000000000..497d5a2e23f2a56fd80a99a37de40d0551abfaa0 --- /dev/null +++ b/seg/configs/_base_/datasets/coco_panoptic_lsj.py @@ -0,0 +1,101 @@ +# dataset settings +from mmcv.transforms import LoadImageFromFile, RandomResize +from mmengine.dataset import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler, CocoPanopticDataset +from mmdet.datasets.transforms import LoadPanopticAnnotations, RandomFlip, RandomCrop, PackDetInputs, Resize +from mmdet.evaluation import CocoPanopticMetric, CocoMetric + +from seg.datasets.coco_ov import CocoPanopticOVDataset +from seg.datasets.pipeliens.loading import LoadPanopticAnnotationsHB + +data_root = 'data/coco/' +backend_args = None +image_size = (1024, 1024) + +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadPanopticAnnotationsHB, + with_bbox=True, + with_mask=True, + with_seg=True, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CocoPanopticOVDataset, + data_root=data_root, + ann_file='annotations/panoptic_train2017.json', + data_prefix=dict(img='train2017/', 
seg='annotations/panoptic_train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadPanopticAnnotations, backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') + ) +] +val_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CocoPanopticDataset, + data_root=data_root, + ann_file='annotations/panoptic_val2017.json', + data_prefix=dict(img='val2017/', seg='annotations/panoptic_val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = [ + dict( + type=CocoPanopticMetric, + ann_file=data_root + 'annotations/panoptic_val2017.json', + seg_prefix=data_root + 'annotations/panoptic_val2017/', + backend_args=backend_args + ), + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['segm'], + backend_args=backend_args + ) +] +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/coco_panoptic_lsj_sam.py b/seg/configs/_base_/datasets/coco_panoptic_lsj_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..8dabc440c2de8a989224949fe566ec0b69f42117 --- /dev/null +++ b/seg/configs/_base_/datasets/coco_panoptic_lsj_sam.py @@ -0,0 +1,107 @@ +# dataset settings +from mmcv.transforms import LoadImageFromFile, RandomResize +from mmengine.dataset import DefaultSampler + +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import LoadPanopticAnnotations, RandomFlip, RandomCrop, PackDetInputs, Resize +from mmdet.evaluation import CocoPanopticMetric, CocoMetric + +from seg.datasets.coco_pan_sam import CocoPanopticSAMDataset +from seg.datasets.pipeliens.loading import LoadPanopticAnnotationsHB, FilterAnnotationsHB + +data_root = 'data/coco/' +backend_args = None +image_size = (1024, 1024) + +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadPanopticAnnotationsHB, + with_bbox=True, + with_mask=True, + with_seg=True, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict( + type=FilterAnnotationsHB, + by_box=False, + by_mask=True, + min_gt_mask_area=32, + ), + dict(type=PackDetInputs) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=CocoPanopticSAMDataset, + data_root=data_root, + ann_file='annotations/panoptic_train2017.json', + data_prefix=dict(img='train2017/', seg='annotations/panoptic_train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadPanopticAnnotations, 
backend_args=backend_args), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor') + ) +] +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=CocoPanopticSAMDataset, + data_root=data_root, + ann_file='annotations/panoptic_val2017.json', + data_prefix=dict(img='val2017/', seg='annotations/panoptic_val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = [ + dict( + type=CocoPanopticMetric, + ann_file=data_root + 'annotations/panoptic_val2017.json', + seg_prefix=data_root + 'annotations/panoptic_val2017/', + backend_args=backend_args + ), + dict( + type=CocoMetric, + ann_file=data_root + 'annotations/instances_val2017.json', + metric=['segm'], + backend_args=backend_args + ) +] +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/davis.py b/seg/configs/_base_/datasets/davis.py new file mode 100644 index 0000000000000000000000000000000000000000..47d9c8277990ebc26eb59ed6d117079eafd9bda7 --- /dev/null +++ b/seg/configs/_base_/datasets/davis.py @@ -0,0 +1,96 @@ +from mmcv import TransformBroadcaster, LoadImageFromFile, RandomResize +from mmdet.datasets.transforms import Resize, RandomFlip, RandomCrop +from mmengine.dataset import DefaultSampler + +from seg.datasets.davis import DAVIS +from seg.datasets.pipeliens.frame_copy import AddSemSeg +from seg.datasets.pipeliens.loading import LoadVideoSegAnnotations, ResizeOri +from seg.datasets.pipeliens.formatting import PackVidSegInputs +from seg.datasets.pipeliens.frame_sampling import VideoClipSample +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler +from seg.evaluation.metrics.vos_metric import VOSMetric + +dataset_type = DAVIS +data_root = 'data/DAVIS' + +backend_args = None +image_size = (1280, 736) + +# dataset settings +train_pipeline = [ + dict( + type=VideoClipSample, + num_selected=2, + interval=2), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadVideoSegAnnotations, with_bbox=True, with_label=True, with_mask=True, with_seg=False), + dict(type=AddSemSeg), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.9, 1.1), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackVidSegInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=image_size, keep_ratio=True), + dict(type=LoadVideoSegAnnotations, with_bbox=True, with_label=True, with_mask=True, with_seg=False), + # dict(type=ResizeOri), + ]), + dict(type=PackVidSegInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version='2017', + ann_file='ImageSets/2017/train.txt', + data_prefix=dict(img='JPEGImages/Full-Resolution/', ann='Annotations/Full-Resolution/'), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + 
persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version='2017', + ann_file='ImageSets/2017/val.txt', + data_prefix=dict(img='JPEGImages/480p/', ann='Annotations/480p/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=VOSMetric, + format_only=True, +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/joint_dataset.py b/seg/configs/_base_/datasets/joint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00b0aacf992434c46ad3bbbbbe0f3ed44a302ed8 --- /dev/null +++ b/seg/configs/_base_/datasets/joint_dataset.py @@ -0,0 +1,61 @@ +# dataset settings +# do not use this config for training, it is only used to create embedding. +from mmengine import read_base +from mmengine.dataset import DefaultSampler, RepeatDataset + +from seg.datasets.concat_dataset import ConcatOVDataset +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler + +with read_base(): + from .coco_panoptic_lsj import train_dataloader as _coco_vid_train_loader + from .ade_panoptic_ov import train_dataloader as _ade_train_loader + from .youtube_vis_2019 import train_dataloader as _yt19_train_loader + from .youtube_vis_2021 import train_dataloader as _yt21_train_loader + from .vipseg import train_dataloader as _vipseg_train_loader + from .cityscapes_panoptic import train_dataloader as _city_train_loader + from .coco_panoptic_lsj import val_dataloader, val_evaluator, test_dataloader, test_evaluator + from .youtube_vis_2019 import image_size + +train_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=ConcatOVDataset, + datasets=[ + dict( + type=RepeatDataset, + dataset=_coco_vid_train_loader.dataset, + times=1, + ), + dict( + type=RepeatDataset, + dataset=_ade_train_loader.dataset, + times=1, + ), + dict( + type=RepeatDataset, + dataset=_yt19_train_loader.dataset, + times=1, + ), + dict( + type=RepeatDataset, + dataset=_yt21_train_loader.dataset, + times=1, + ), + dict( + type=RepeatDataset, + dataset=_vipseg_train_loader.dataset, + times=1, + ), + dict( + type=RepeatDataset, + dataset=_city_train_loader.dataset, + times=1, + ), + ], + ) +) + diff --git a/seg/configs/_base_/datasets/objects365v2_detection_lsj.py b/seg/configs/_base_/datasets/objects365v2_detection_lsj.py new file mode 100644 index 0000000000000000000000000000000000000000..5ada3b72527a3e11072997ebeddcd390405100e5 --- /dev/null +++ b/seg/configs/_base_/datasets/objects365v2_detection_lsj.py @@ -0,0 +1,100 @@ +# dataset settings +from mmcv import LoadImageFromFile, RandomResize +from mmdet.datasets import Objects365V2Dataset, AspectRatioBatchSampler +from mmdet.datasets.transforms import Resize, RandomFlip, PackDetInputs, RandomCrop +from mmdet.datasets.transforms import LoadAnnotations +from mmdet.evaluation import CocoMetric +from mmengine.dataset import DefaultSampler + +dataset_type = Objects365V2Dataset +data_root = "s3://wangyudong/obj365_v2/" + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in 
versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +image_size = (1024, 1024) +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadAnnotations, + with_bbox=True, + with_mask=False, + with_seg=False, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/zhiyuan_objv2_train.json', + data_prefix=dict(img='train/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/zhiyuan_objv2_val.json', + data_prefix=dict(img='val/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/zhiyuan_objv2_val.json', + metric='bbox', + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/objects365v2_instance_lsj.py b/seg/configs/_base_/datasets/objects365v2_instance_lsj.py new file mode 100644 index 0000000000000000000000000000000000000000..a2eee683ce011e2ea250ee928cbf687b2a33ec13 --- /dev/null +++ b/seg/configs/_base_/datasets/objects365v2_instance_lsj.py @@ -0,0 +1,106 @@ +# dataset settings +from mmcv import LoadImageFromFile, RandomResize +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import Resize, RandomFlip, PackDetInputs, RandomCrop +from mmdet.datasets.transforms import LoadAnnotations +from mmdet.evaluation import CocoMetric +from mmengine.dataset import DefaultSampler + +from seg.datasets.pipeliens.frame_copy import AddSemSeg +from seg.datasets.objects365 import Objects365V2InsDataset + +dataset_type = Objects365V2InsDataset +data_root = "" + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +image_size = (1024, 1024) +train_pipeline = [ 
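+    # This pipeline loads boxes and (pre-generated) instance masks but no
+    # semantic map (with_seg=False); AddSemSeg appears to fill in a placeholder
+    # semantic-seg field so the sample format matches the panoptic datasets.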
+ dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadAnnotations, + with_bbox=True, + with_mask=True, + with_seg=False, + backend_args=backend_args), + dict(type=AddSemSeg), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=PackDetInputs) +] +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='data/obj365_ins/obj365v2_train.gt.json', + data_prefix=dict(img='s3://wangyudong/obj365_v2/train/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args) +) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/zhiyuan_objv2_val.json', + data_prefix=dict(img='val/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/zhiyuan_objv2_val.json', + metric='bbox', + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/v3det.py b/seg/configs/_base_/datasets/v3det.py new file mode 100644 index 0000000000000000000000000000000000000000..f28552661f69953e3db227645f460534a5c8255c --- /dev/null +++ b/seg/configs/_base_/datasets/v3det.py @@ -0,0 +1,106 @@ +# dataset settings +from mmcv import LoadImageFromFile, RandomResize +from mmdet.datasets import AspectRatioBatchSampler +from mmdet.datasets.transforms import LoadAnnotations, RandomFlip, Resize, RandomCrop, PackDetInputs +from mmdet.evaluation import CocoMetric +from mmengine.dataset import DefaultSampler + +from seg.datasets.pipeliens.loading import FilterAnnotationsHB +from seg.datasets.v3det import V3DetDataset + +dataset_type = V3DetDataset +data_root = 'data/V3Det/' + +backend_args = None + + +image_size = (1024, 1024) +train_pipeline = [ + dict( + type=LoadImageFromFile, + to_float32=True, + backend_args=backend_args), + dict( + type=LoadAnnotations, + with_bbox=True, + with_mask=False, + with_seg=False, + backend_args=backend_args), + dict(type=RandomFlip, prob=0.5), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict( + type=FilterAnnotationsHB, + by_box=True, + by_mask=False, + min_gt_bbox_wh=(8, 8) + ), + dict(type=PackDetInputs) +] + +test_pipeline = [ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + 
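The train pipelines in these dataset configs follow the large-scale-jitter recipe: RandomResize with ratio_range=(0.1, 2.0) against the 1024x1024 target, followed by an absolute 1024x1024 RandomCrop. A minimal sketch of the scale arithmetic under mmdet's usual ratio_range semantics; the helper name below is hypothetical, not repo code.

import random

def sample_lsj_scale(base=(1024, 1024), ratio_range=(0.1, 2.0)):
    # RandomResize draws one ratio and multiplies the target scale by it;
    # keep_ratio=True then fits the image inside that jittered scale.
    r = random.uniform(*ratio_range)
    return int(base[0] * r), int(base[1] * r)

# Draws near 0.1 leave mostly padding inside the fixed 1024x1024 crop, while
# draws near 2.0 zoom in so the crop sees only a small part of the image.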
dict(type=LoadAnnotations, with_bbox=True), + dict( + type=PackDetInputs, + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=AspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/v3det_2023_v1_train.json', + data_prefix=dict(img=''), + filter_cfg=dict(filter_empty_gt=True, min_size=4), + pipeline=train_pipeline, + backend_args=backend_args + ) +) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/v3det_2023_v1_val.json', + data_prefix=dict(img=''), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args + ) +) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=CocoMetric, + ann_file=data_root + 'annotations/v3det_2023_v1_val.json', + metric='bbox', + format_only=False, + backend_args=backend_args, + use_mp_eval=True, + proposal_nums=[300] +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/vipseg.py b/seg/configs/_base_/datasets/vipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..fef9c57a151155be477c4d0312f99673ba880968 --- /dev/null +++ b/seg/configs/_base_/datasets/vipseg.py @@ -0,0 +1,94 @@ +from mmcv import TransformBroadcaster, LoadImageFromFile, RandomResize +from mmdet.datasets.transforms import Resize, RandomFlip, RandomCrop +from mmengine.dataset import DefaultSampler + +from seg.datasets.pipeliens.loading import LoadVideoSegAnnotations, ResizeOri +from seg.datasets.pipeliens.formatting import PackVidSegInputs +from seg.datasets.pipeliens.frame_sampling import VideoClipSample +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler +from seg.datasets.vipseg import VIPSegDataset +from seg.evaluation.metrics.vip_seg_metric import VIPSegMetric + +dataset_type = VIPSegDataset +data_root = 'data/VIPSeg' + +backend_args = None +image_size = (1280, 736) + +# dataset settings +train_pipeline = [ + dict( + type=VideoClipSample, + num_selected=2, + interval=2), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadVideoSegAnnotations, with_bbox=True, with_label=True, with_mask=True, with_seg=True), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.8, 2.), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackVidSegInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadVideoSegAnnotations, with_bbox=True, with_label=True, with_mask=True, with_seg=True), + dict(type=Resize, scale=image_size, keep_ratio=True), + dict(type=ResizeOri), + ]), + dict(type=PackVidSegInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='train.txt', + 
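The video training pipelines here wrap the per-frame transforms in TransformBroadcaster with share_random_params=True, so every frame of a sampled clip receives the same random augmentation. A minimal sketch of that idea for the flip step, assuming numpy frames; this is illustrative, not the mmcv implementation.

import random
import numpy as np

def flip_clip(frames, prob=0.5):
    # One random decision shared by the whole clip keeps the frames consistent,
    # mirroring what share_random_params=True does for RandomFlip above.
    do_flip = random.random() < prob
    return [np.ascontiguousarray(f[:, ::-1]) if do_flip else f for f in frames]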
data_prefix=dict(img='imgs/', seg='panomasks/'), + # check whether it is necessary. + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val.txt', + data_prefix=dict(img='imgs/', seg='panomasks/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=VIPSegMetric, + metric=['VPQ@1', 'VPQ@2', 'VPQ@4', 'VPQ@6'], +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/youtube_vis_2019.py b/seg/configs/_base_/datasets/youtube_vis_2019.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac8ca1f4d3060eb0c16bd88b8bf5716d1d6a6d4 --- /dev/null +++ b/seg/configs/_base_/datasets/youtube_vis_2019.py @@ -0,0 +1,99 @@ +from mmcv import TransformBroadcaster, LoadImageFromFile, RandomResize +from mmdet.datasets.transforms import LoadTrackAnnotations, Resize, RandomFlip, PackTrackInputs, RandomCrop +from mmdet.evaluation import YouTubeVISMetric +from mmengine.dataset import DefaultSampler + +from seg.datasets.youtube_vis_dataset import YouTubeVISDatasetV2 +from seg.datasets.pipeliens.formatting import PackVidSegInputs +from seg.datasets.pipeliens.frame_copy import AddSemSeg +from seg.datasets.pipeliens.frame_sampling import VideoClipSample +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler + +dataset_type = YouTubeVISDatasetV2 +data_root = 'data/youtube_vis_2019/' +dataset_version = data_root[-5:-1] # 2019 or 2021 + +backend_args = None +image_size = (1280, 736) + +# dataset settings +train_pipeline = [ + dict( + type=VideoClipSample, + num_selected=2, + interval=2), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadTrackAnnotations, with_mask=True), + dict(type=AddSemSeg), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.8, 2.), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackVidSegInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=image_size, keep_ratio=True), + dict(type=LoadTrackAnnotations, with_mask=True), + ]), + dict(type=PackTrackInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_2019_train.json', + data_prefix=dict(img_path='train/JPEGImages'), + # check whether it is necessary. 
+ filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_2019_valid.json', + data_prefix=dict(img_path='valid/JPEGImages'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=YouTubeVISMetric, + metric='youtube_vis_ap', + outfile_prefix='./youtube_vis_2019_results', + format_only=True +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/youtube_vis_2021.py b/seg/configs/_base_/datasets/youtube_vis_2021.py new file mode 100644 index 0000000000000000000000000000000000000000..bae6b1a27a0122c6069936feb385fc01e7e12e68 --- /dev/null +++ b/seg/configs/_base_/datasets/youtube_vis_2021.py @@ -0,0 +1,100 @@ +from mmcv import TransformBroadcaster, LoadImageFromFile, RandomResize +from mmdet.datasets.transforms import LoadTrackAnnotations, Resize, RandomFlip, PackTrackInputs, RandomCrop +from mmdet.evaluation import YouTubeVISMetric +from mmengine.dataset import DefaultSampler + + +from seg.datasets.youtube_vis_dataset import YouTubeVISDatasetV2 +from seg.datasets.pipeliens.formatting import PackVidSegInputs +from seg.datasets.pipeliens.frame_copy import AddSemSeg +from seg.datasets.pipeliens.frame_sampling import VideoClipSample +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler + +dataset_type = YouTubeVISDatasetV2 +data_root = 'data/youtube_vis_2021/' +dataset_version = data_root[-5:-1] # 2019 or 2021 + +backend_args = None +image_size = (1280, 736) + +# dataset settings +train_pipeline = [ + dict( + type=VideoClipSample, + num_selected=2, + interval=2), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadTrackAnnotations, with_mask=True), + dict(type=AddSemSeg), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.8, 2.), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackVidSegInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=image_size, keep_ratio=True), + dict(type=LoadTrackAnnotations, with_mask=True), + ]), + dict(type=PackTrackInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_2021_train.json', + data_prefix=dict(img_path='train/JPEGImages'), + # check whether it is necessary. 
+ filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_2021_valid.json', + data_prefix=dict(img_path='valid/JPEGImages'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=YouTubeVISMetric, + metric='youtube_vis_ap', + outfile_prefix='./youtube_vis_2021_results', + format_only=True +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/datasets/youtube_vis_ovis.py b/seg/configs/_base_/datasets/youtube_vis_ovis.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b35141c01baf3be0462b58dac7d58cd5787b45 --- /dev/null +++ b/seg/configs/_base_/datasets/youtube_vis_ovis.py @@ -0,0 +1,100 @@ +from mmcv import TransformBroadcaster, LoadImageFromFile, RandomResize +from mmdet.datasets.transforms import LoadTrackAnnotations, Resize, RandomFlip, PackTrackInputs, RandomCrop +from mmdet.evaluation import YouTubeVISMetric +from mmengine.dataset import DefaultSampler + + +from seg.datasets.youtube_vis_dataset import YouTubeVISDatasetV2 +from seg.datasets.pipeliens.formatting import PackVidSegInputs +from seg.datasets.pipeliens.frame_copy import AddSemSeg +from seg.datasets.pipeliens.frame_sampling import VideoClipSample +from seg.datasets.samplers.batch_sampler import VideoSegAspectRatioBatchSampler + +dataset_type = YouTubeVISDatasetV2 +data_root = 'data/ovis/' +dataset_version = data_root[-5:-1] # 2019 or 2021 + +backend_args = None +image_size = (1280, 736) + +# dataset settings +train_pipeline = [ + dict( + type=VideoClipSample, + num_selected=2, + interval=2), + dict( + type=TransformBroadcaster, + share_random_params=True, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=LoadTrackAnnotations, with_mask=True), + dict(type=AddSemSeg), + dict( + type=RandomResize, + resize_type=Resize, + scale=image_size, + ratio_range=(.8, 2.), + keep_ratio=True, + ), + dict( + type=RandomCrop, + crop_size=image_size, + crop_type='absolute', + recompute_bbox=True, + allow_negative_crop=True), + dict(type=RandomFlip, prob=0.5), + ]), + dict(type=PackVidSegInputs) +] + +test_pipeline = [ + dict( + type=TransformBroadcaster, + transforms=[ + dict(type=LoadImageFromFile, backend_args=backend_args), + dict(type=Resize, scale=image_size, keep_ratio=True), + dict(type=LoadTrackAnnotations, with_mask=True), + ]), + dict(type=PackTrackInputs) +] + +# dataloader +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type=DefaultSampler, shuffle=True), + batch_sampler=dict(type=VideoSegAspectRatioBatchSampler), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_ovis_train.json', + data_prefix=dict(img_path='train/'), + # check whether it is necessary. 
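For reference, the dataset_version assignment shared by the three video configs above just slices the trailing directory name out of data_root; with the roots used here it yields:

assert 'data/youtube_vis_2019/'[-5:-1] == '2019'
assert 'data/youtube_vis_2021/'[-5:-1] == '2021'
assert 'data/ovis/'[-5:-1] == 'ovis'  # not a year, despite the '2019 or 2021'
                                      # comment; presumably accepted as-is by
                                      # YouTubeVISDatasetV2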
+ filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type=DefaultSampler, shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + dataset_version=dataset_version, + ann_file='annotations/youtube_vis_ovis_valid.json', + data_prefix=dict(img_path='valid/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type=YouTubeVISMetric, + metric='youtube_vis_ap', + outfile_prefix='./youtube_vis_ovis_results', + format_only=True +) +test_evaluator = val_evaluator diff --git a/seg/configs/_base_/default_runtime.py b/seg/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..2dcfb27988cbabd636be5a90cbb809fab5f95ee0 --- /dev/null +++ b/seg/configs/_base_/default_runtime.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.runner import LogProcessor +from mmengine.visualization import LocalVisBackend + +from mmdet.engine.hooks import DetVisualizationHook +from mmdet.visualization import DetLocalVisualizer + +default_scope = None + +default_hooks = dict( + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=50), + param_scheduler=dict(type=ParamSchedulerHook), + checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=1), + sampler_seed=dict(type=DistSamplerSeedHook), + visualization=dict(type=DetVisualizationHook)) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type=LocalVisBackend)] +visualizer = dict( + type=DetLocalVisualizer, vis_backends=vis_backends, name='visualizer') +log_processor = dict(type=LogProcessor, window_size=50, by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/seg/configs/_base_/schedules/schedule_12e.py b/seg/configs/_base_/schedules/schedule_12e.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc78fb293a57d5f02404c057368c5c3b64e8fd5 --- /dev/null +++ b/seg/configs/_base_/schedules/schedule_12e.py @@ -0,0 +1,59 @@ +from mmengine.optim import LinearLR, MultiStepLR, OptimWrapper +from mmengine.runner import EpochBasedTrainLoop, ValLoop, TestLoop +from torch.optim import AdamW + +# training schedule for 50e +train_cfg = dict( + type=EpochBasedTrainLoop, + max_epochs=12, + val_interval=2, +) +val_cfg = dict(type=ValLoop) +test_cfg = dict(type=TestLoop) + +# learning rate +param_scheduler = [ + dict( + type=LinearLR, + start_factor=0.001, + by_epoch=False, + begin=0, + end=500 + ), + dict( + type=MultiStepLR, + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1 + ) +] + +_embed_multi = dict(lr_mult=1.0, decay_mult=0.0) +optim_wrapper = dict( + type=OptimWrapper, + optimizer=dict( + type=AdamW, + lr=0.0001, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999) + ), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1, decay_mult=1.0), + 'query_embed': _embed_multi, + 'query_feat': _embed_multi, + 'level_embed': _embed_multi, + }, + norm_decay_mult=0.0 + ), + clip_grad=dict(max_norm=0.01, norm_type=2) +) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. 
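As a quick reading of the schedule above: LinearLR warms the AdamW lr up from 1e-7 to 1e-4 over the first 500 iterations, and MultiStepLR then decays it by 10x at epochs 8 and 11, giving 1e-4 for epochs 0-7, 1e-5 for epochs 8-10, and 1e-6 for epoch 11. The auto_scale_lr block is generally understood to apply the linear scaling rule against base_batch_size; a hedged sketch (mmengine's exact handling may differ in detail):

def linearly_scaled_lr(base_lr=1e-4, num_gpus=8, samples_per_gpu=2,
                       base_batch_size=16):
    # e.g. 4 GPUs x 2 samples/GPU -> 1e-4 * 8 / 16 = 5e-5
    return base_lr * (num_gpus * samples_per_gpu) / base_batch_size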
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=True, base_batch_size=16) diff --git a/seg/configs/m2ov_val/datasets/ade.py b/seg/configs/m2ov_val/datasets/ade.py new file mode 100644 index 0000000000000000000000000000000000000000..40223c0ad8edfe9912f49890286ddbe416cf9cf3 --- /dev/null +++ b/seg/configs/m2ov_val/datasets/ade.py @@ -0,0 +1,41 @@ +from mmdet.models import BatchFixedSizePad +from mmengine import read_base + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.schedules.schedule_12e import * + from ..._base_.datasets.ade_panoptic_ov import train_dataloader, image_size + from ..._base_.datasets.ade_panoptic import val_dataloader, val_evaluator, test_dataloader, test_evaluator + from ..._base_.datasets.joint_dataset import train_dataloader as training_loader + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=batch_augments +) + +num_things_classes = 100 +num_stuff_classes = 50 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'ADEPanopticOVDataset' diff --git a/seg/configs/m2ov_val/datasets/cityscapes.py b/seg/configs/m2ov_val/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e7474c7e657fdd727c0d96511155a602accf74 --- /dev/null +++ b/seg/configs/m2ov_val/datasets/cityscapes.py @@ -0,0 +1,41 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor +from seg.models.utils import NO_OBJ + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.cityscapes_panoptic import * + from ..._base_.schedules.schedule_12e import * + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=NO_OBJ, + batch_augments=batch_augments +) + +num_things_classes = 11 +num_stuff_classes = 8 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'CityscapesPanopticDataset' diff --git a/seg/configs/m2ov_val/datasets/coco.py b/seg/configs/m2ov_val/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..e894aea5b0a41b39f0fa022634cf85d83822abcf --- /dev/null +++ b/seg/configs/m2ov_val/datasets/coco.py @@ -0,0 +1,40 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.coco_panoptic_lsj import * + from ..._base_.schedules.schedule_12e import * + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + 
mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=batch_augments +) + +num_things_classes = 80 +num_stuff_classes = 53 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'CocoPanopticOVDataset' diff --git a/seg/configs/m2ov_val/datasets/coco_pan_point.py b/seg/configs/m2ov_val/datasets/coco_pan_point.py new file mode 100644 index 0000000000000000000000000000000000000000..0b280998e8f06b339fa79d1f2f643347a859c126 --- /dev/null +++ b/seg/configs/m2ov_val/datasets/coco_pan_point.py @@ -0,0 +1,35 @@ +from mmengine.config import read_base + +from seg.evaluation.metrics.ins_cls_iou_metric import InsClsIoUMetric +from seg.models.data_preprocessor import OVSAMDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.coco_panoptic_lsj_sam import * + from ..._base_.schedules.schedule_12e import * + +data_preprocessor = dict( + type=OVSAMDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=None, + use_point_pseudo_box=True +) + +num_things_classes = 80 +num_stuff_classes = 0 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'CocoPanopticOVDataset' + +val_evaluator = dict( + type=InsClsIoUMetric, + with_score=False, +) +test_evaluator = val_evaluator diff --git a/seg/configs/m2ov_val/datasets/davis.py b/seg/configs/m2ov_val/datasets/davis.py new file mode 100644 index 0000000000000000000000000000000000000000..2829ccf083646cd96783ae410ee849bfaaa28ae4 --- /dev/null +++ b/seg/configs/m2ov_val/datasets/davis.py @@ -0,0 +1,41 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor +from seg.models.utils import NO_OBJ + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.davis import * + from ..._base_.schedules.schedule_12e import * + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=NO_OBJ + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=NO_OBJ, + batch_augments=batch_augments +) + +num_things_classes = 80 +num_stuff_classes = 0 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'CocoOVDataset' diff --git a/seg/configs/m2ov_val/datasets/vipseg.py b/seg/configs/m2ov_val/datasets/vipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..3080943057d5af54ac63a7fc6226db72916b9c3a --- /dev/null +++ b/seg/configs/m2ov_val/datasets/vipseg.py @@ -0,0 +1,43 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.vipseg import * + from ..._base_.schedules.schedule_12e import * + +batch_augments = [ + dict( + 
type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=batch_augments +) + +num_things_classes = 58 +num_stuff_classes = 66 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'VIPSegDataset' +default_hooks.update( + logger=dict(type=LoggerHook, interval=1), +) diff --git a/seg/configs/m2ov_val/datasets/y19.py b/seg/configs/m2ov_val/datasets/y19.py new file mode 100644 index 0000000000000000000000000000000000000000..de739d5865339cf36d26b52e10ce325f54898e3e --- /dev/null +++ b/seg/configs/m2ov_val/datasets/y19.py @@ -0,0 +1,43 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.youtube_vis_2019 import * + from ..._base_.schedules.schedule_12e import * + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=batch_augments +) + +num_things_classes = 40 +num_stuff_classes = 0 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'YouTubeVISDataset_2019' +default_hooks.update( + logger=dict(type=LoggerHook, interval=1), +) diff --git a/seg/configs/m2ov_val/datasets/y21.py b/seg/configs/m2ov_val/datasets/y21.py new file mode 100644 index 0000000000000000000000000000000000000000..29fdd1307ba18ba0019b78bb82d3fbd4436ceb1d --- /dev/null +++ b/seg/configs/m2ov_val/datasets/y21.py @@ -0,0 +1,45 @@ +from mmengine.config import read_base + +from mmdet.models import BatchFixedSizePad + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + +with read_base(): + from ..._base_.default_runtime import * + from ..._base_.datasets.youtube_vis_2021 import * + from ..._base_.schedules.schedule_12e import * + from ..._base_.datasets.joint_dataset import train_dataloader as training_loader + + +batch_augments = [ + dict( + type=BatchFixedSizePad, + size=(image_size[1], image_size[0]), + img_pad_value=0, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255 + ) +] +data_preprocessor = dict( + type=VideoSegDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + pad_mask=True, + mask_pad_value=0, + pad_seg=True, + seg_pad_value=255, + batch_augments=batch_augments +) + +num_things_classes = 40 +num_stuff_classes = 0 +num_classes = num_things_classes + num_stuff_classes + +ov_datasets_name = 'YouTubeVISDataset_2021' +default_hooks.update( + logger=dict(type=LoggerHook, interval=1), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_ade.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_ade.py new file mode 100644 index 0000000000000000000000000000000000000000..c357d4582a79bee968d5f1b1a45f088cf7f9880d --- /dev/null +++ 
b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_ade.py @@ -0,0 +1,27 @@ +from mmengine import read_base + +with read_base(): + from .datasets.ade import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=False, + ), +) +overlapping = dict( + train=training_loader.dataset, + test=test_dataloader.dataset +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_cityscapes.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f98003417db7bc50fb3d53708d19177dfc9729be --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_cityscapes.py @@ -0,0 +1,23 @@ +from mmengine import read_base + +with read_base(): + from .datasets.cityscapes import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=False, + ), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_coco.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..167d07e7e9037a8382110147f42ef0aa73f567df --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_coco.py @@ -0,0 +1,23 @@ +from mmengine import read_base + +with read_base(): + from .datasets.coco import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=True, + ), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_davis.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_davis.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a26c41dc3b23d0bbb73eb940872aaee35755d3 --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_davis.py @@ -0,0 +1,33 @@ +from mmengine import read_base + +from seg.models.detectors import Mask2formerVideoMinVIS + +with read_base(): + from .datasets.davis import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + type=Mask2formerVideoMinVIS, + clip_size=5, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=15, + overlap=0, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=False, + semantic_on=False, + instance_on=False, + proposal_on=True, + num_proposals=25, + ), +) diff --git 
a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_mose.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_mose.py new file mode 100644 index 0000000000000000000000000000000000000000..c19d709669cb59670c8680690701b3639bdbe4b6 --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_mose.py @@ -0,0 +1,33 @@ +from mmengine import read_base + +from seg.models.detectors import Mask2formerVideoMinVIS + +with read_base(): + from .datasets.mose import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + type=Mask2formerVideoMinVIS, + clip_size=5, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=15, + overlap=0, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=False, + semantic_on=False, + instance_on=False, + proposal_on=True, + num_proposals=25, + ), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_vipseg.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_vipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..c25e2b80d7e5db0b661db0c2ad9d8ff1571922e1 --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_vipseg.py @@ -0,0 +1,38 @@ +from mmengine import read_base + +from seg.models.detectors import Mask2formerVideoMinVIS + +with read_base(): + from .datasets.vipseg import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + type=Mask2formerVideoMinVIS, + clip_size=2, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=15, + overlap=0, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=True, + semantic_on=False, + instance_on=False, + ), +) + +val_evaluator = dict( + type=VIPSegMetric, + metric=['VPQ@1', 'VPQ@2', 'VPQ@4', 'VPQ@6'], + format_only=True, +) +test_evaluator = val_evaluator diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y19.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y19.py new file mode 100644 index 0000000000000000000000000000000000000000..556a80bded9e162ee209b44c244a3beec598b47f --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y19.py @@ -0,0 +1,31 @@ +from mmengine import read_base + +from seg.models.detectors import Mask2formerVideoMinVIS + +with read_base(): + from .datasets.y19 import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + type=Mask2formerVideoMinVIS, + clip_size=5, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=15, + overlap=0, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=False, + semantic_on=False, + instance_on=True, + ), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y21.py b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y21.py new file mode 100644 index 0000000000000000000000000000000000000000..fb20214a31ca81895c6bafa6dfdb31f3a7933a7a --- /dev/null +++ 
b/seg/configs/m2ov_val/eval_m2_convl_300q_ov_y21.py @@ -0,0 +1,31 @@ +from mmengine import read_base + +from seg.models.detectors import Mask2formerVideoMinVIS + +with read_base(): + from .datasets.y21 import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + type=Mask2formerVideoMinVIS, + clip_size=5, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=15, + overlap=0, + panoptic_head=dict( + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=False, + semantic_on=False, + instance_on=True, + ), +) diff --git a/seg/configs/m2ov_val/eval_m2_convl_ov_coco_pan_point.py b/seg/configs/m2ov_val/eval_m2_convl_ov_coco_pan_point.py new file mode 100644 index 0000000000000000000000000000000000000000..c434759a7b368abbca292323ca665f1d601d4d44 --- /dev/null +++ b/seg/configs/m2ov_val/eval_m2_convl_ov_coco_pan_point.py @@ -0,0 +1,25 @@ +from mmengine import read_base + +with read_base(): + from .datasets.coco_pan_point import * + from .models.m2_convl_300q import * + +model.update( + data_preprocessor=data_preprocessor, + inference_sam=True, + panoptic_head=dict( + enable_box_query=True, + ov_classifier_name=f'{ov_model_name}_{ov_datasets_name}', + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + panoptic_fusion_head=dict( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + ), + test_cfg=dict( + panoptic_on=False, + semantic_on=False, + instance_on=True, + ), +) diff --git a/seg/configs/m2ov_val/models/m2_convl_300q.py b/seg/configs/m2ov_val/models/m2_convl_300q.py new file mode 100644 index 0000000000000000000000000000000000000000..c5858743eb1399fc4619443d890d33bd779d59bc --- /dev/null +++ b/seg/configs/m2ov_val/models/m2_convl_300q.py @@ -0,0 +1,151 @@ +from torch.nn import GroupNorm, ReLU + +from mmdet.models import MSDeformAttnPixelDecoder, CrossEntropyLoss, DiceLoss, FocalLoss +from mmdet.models.task_modules.assigners import HungarianAssigner, ClassificationCost, CrossEntropyLossCost, DiceCost +from mmdet.models.task_modules.samplers import MaskPseudoSampler + +from seg.models.detectors import Mask2formerVideo +from seg.models.fusion_head import OMGFusionHead +from seg.models.heads import Mask2FormerVideoHead +from seg.models.backbones import OpenCLIPBackbone + +model = dict( + type=Mask2formerVideo, + data_preprocessor=None, # to fill + backbone=dict( + type=OpenCLIPBackbone, + model_name='convnext_large_d_320', + fix=True, + init_cfg=dict( + type='clip_pretrain', + checkpoint='laion2b_s29b_b131k_ft_soup' + ) + ), + panoptic_head=dict( + init_cfg=dict( + type='Pretrained', + checkpoint='./models/m2_convl_12e.pth', + prefix='panoptic_head.' 
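Each of these eval configs rebuilds ov_classifier_name from the shared backbone name and the ov_datasets_name defined in its dataset config; the resulting string presumably selects a pre-computed text-embedding classifier. For example, with the YouTube-VIS 2021 values from this diff:

ov_model_name = 'convnext_large_d_320'       # set at the end of m2_convl_300q.py
ov_datasets_name = 'YouTubeVISDataset_2021'  # from m2ov_val/datasets/y21.py
assert f'{ov_model_name}_{ov_datasets_name}' == 'convnext_large_d_320_YouTubeVISDataset_2021'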
+ ), + type=Mask2FormerVideoHead, + sphere_cls=True, + ov_classifier_name=None, + logit=None, + in_channels=[192, 384, 768, 1536], # pass to pixel_decoder inside + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_things_classes=0, + num_stuff_classes=0, + num_queries=300, + num_transformer_feat_level=3, + pixel_decoder=dict( + type=MSDeformAttnPixelDecoder, + num_outs=3, + norm_cfg=dict(type=GroupNorm, num_groups=32), + act_cfg=dict(type=ReLU), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type=ReLU, inplace=True)))), + positional_encoding=dict(num_feats=128, normalize=True)), + enforce_decoder_input_project=False, + positional_encoding=dict(num_feats=128, normalize=True), + transformer_decoder=dict( # Mask2FormerTransformerDecoder + return_intermediate=True, + num_layers=9, + layer_cfg=dict( # Mask2FormerTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True))), + init_cfg=None), + loss_cls=dict( + type=CrossEntropyLoss, + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=None # [1.0] * num_classes + [0.1] + ), + loss_mask=dict( + type=CrossEntropyLoss, + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type=DiceLoss, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0), + loss_iou=dict( + type=FocalLoss, + use_sigmoid=True, + loss_weight=2.0, + reduction='mean' + ) + ), + panoptic_fusion_head=dict( + type=OMGFusionHead, + num_things_classes=0, + num_stuff_classes=0, + loss_panoptic=None, + init_cfg=None + ), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type=HungarianAssigner, + match_costs=[ + dict(type=ClassificationCost, weight=2.0), + dict( + type=CrossEntropyLossCost, weight=5.0, use_sigmoid=True), + dict(type=DiceCost, weight=5.0, pred_act=True, eps=1.0) + ]), + sampler=dict(type=MaskPseudoSampler)), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
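A rough sketch of the per-pixel filtering this comment refers to, kept separate from the real fusion-head logic (which additionally weighs class scores and applies the iou_thr overlap test configured above):

import torch

def binarize_query_masks(mask_probs: torch.Tensor, thr: float = 0.5) -> torch.Tensor:
    # mask_probs: (num_queries, H, W) sigmoid mask scores; pixels scoring below
    # `thr` are dropped from a query's mask during panoptic post-processing.
    return mask_probs > thr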
+ filter_low_score=True, + object_mask_thr=0., + ), + init_cfg=None +) + +ov_model_name = 'convnext_large_d_320' diff --git a/seg/datasets/ade_ov.py b/seg/datasets/ade_ov.py new file mode 100644 index 0000000000000000000000000000000000000000..234f80bc9785c5dd19965135753cc52a8d0c1d10 --- /dev/null +++ b/seg/datasets/ade_ov.py @@ -0,0 +1,370 @@ +import copy +from typing import List + +from mmdet.registry import DATASETS +from mmdet.datasets.coco_panoptic import CocoPanopticDataset +from mmengine import get_local_path + + +@DATASETS.register_module() +class ADEPanopticOVDataset(CocoPanopticDataset): + """ADE Open Vocabulary dataset for Panoptic segmentation. + The class names are changed. + """ + + METAINFO = { + 'classes': + ( + 'bed,beds', + 'windowpane,window,windows', + 'cabinet,cabinets,wall mounted cabine', + 'person,child,girl,boy,woman,man,people,children,girls,boys,women,men', + 'door,double door,doors', + 'table,tables,tablecloth', + 'curtain,drape,drapery,mantle,pall', + 'chair,chairs', + 'car,automobile,cars', + 'painting,picture,paintings,pictures,wallart,framed canvas', + 'sofa,couch,sofas,couches', + 'shelf,shelves', + 'mirror,mirrors', + 'armchair,armchairs', + 'seat,seats', + 'fence,fencing', + 'desk,desks', + 'wardrobe,closet,press,wardrobes,closets', + 'lamp,lamps', + 'bathtub,bathing tub,bath,tub', + 'railing,rail', + 'cushion,cushions', + 'box,boxes', + 'column,pillar', + 'signboard,sign,signboards,signs', + 'chest of drawers,chest,bureau,dresser', + 'counter', + 'sink', + 'fireplace,hearth,open fireplace', + 'refrigerator,icebox', + 'stairs,steps', + 'case,display case,showcase,vitrine', + 'pool table,billiard table,snooker table', + 'pillow,pillows', + 'screen door,shower door', + 'bookcase', + 'coffee table,cocktail table', + 'toilet,commode,crapper,potty', + 'flower,flowers', + 'book,books', + 'bench,benches', + 'countertop,counter top,worktop', + 'stove,kitchen stove,kitchen range,kitchen range,cooking stove', + 'palm tree,palm trees', + 'kitchen island', + 'computer,computing machine,computing device,data processor,electronic computer,information processing system', + 'swivel chair', + 'boat', + 'arcade machine,arcade machines', + 'bus,autobus,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehicle', + 'towel', + 'light bulb,lightbulb,bulb,incandescent lamp,electric light,electric-light bulb', + 'truck,motortruck', + 'chandelier,pendant,pendent', + 'awning,sunshade,sunblind', + 'streetlight,street lamp', + 'booth,cubicle,stall,kiosk', + 'television receiver,television,television set,tv,tv set', + 'airplane,aeroplane,airplanes,aeroplanes', + 'apparel,wearing apparel,dress,clothes', + 'pole', + 'bannister,banister,balustrade,balusters,handrail', + 'ottoman,pouf,pouffe,puff,hassock', + 'bottle,bottles,water bottle', + 'van', + 'ship', + 'fountain', + 'washer,automatic washer,washing machine', + 'plaything,toy,toys', + 'stool,stools', + 'barrel,cask,barrels,casks', + 'basket,handbasket', + 'bag,bags,gift bag,paper bag', + 'minibike,motorbike', + 'oven', + 'ball,balls', + 'food,solid food', + 'step,stair', + 'trade name,brand name,brand,marque', + 'microwave,microwave oven', + 'plant pots,plant pot,flower pot,flowerpot,planter', + 'animal,animate being,dog,cat,horse,cow,sheep,zebra,girraffe,bird', + 'bicycle,bike', + 'dishwasher,dish washer,dishwashing machine', + 'projection screen', + 'sculpture,sculptures', + 'exhaust hood', + 'sconce,sconce lamp,sconce light', + 'vase,vases', + 'traffic light,traffic signal,traffic lights', + 'tray,trays', + 'ashcan,trash 
can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin', + 'ceiling fan,floor fan', + 'plate,plates', + 'monitor,monitoring device,monitors', + 'bulletin board,notice board', + 'radiator', + 'cup,cups,drinking glass,drinking glasses', + 'clock', + 'flag,flags', + + 'wall,walls,brick wall,stone wall,interior wall', + 'building,buildings,edifice,edifices', + 'sky,clouds', + 'floor,flooring', + 'tree,trees', + 'ceiling', + 'road,route,street,roads,streets,routes', + 'grass,grass field', + 'sidewalk,pavement', + 'earth,ground', + 'mountain,mount,mountains', + 'plant,flora,plant life,plants,bushes', + 'water', + 'house exterior', + 'sea,ocean', + 'rug,carpet,carpeting', + 'field', + 'rock,stone,rocks,stones', + 'pedestal', + 'sand', + 'skyscraper,skyscrapers', + 'grandstand,covered stand', + 'path', + 'runway', + 'stairway,staircase', + 'river', + 'bridge,span', + 'window screen,door screen', + 'hill', + 'bar', + 'hovel,hut,hutch,shack,shanty', + 'tower,towers', + 'dirt track', + 'land,soil', + 'escalator,moving staircase,moving stairway', + 'buffet,sideboard', + 'poster,posting,placard,notice,bill,card', + 'stage', + 'conveyer belt,conveyor belt,conveyer,conveyor,transporter', + 'canopy', + 'swimming pool,swimming bath', + 'waterfall,falls', + 'tent,collapsible shelter', + 'cradle', + 'tank,storage tank', + 'lake', + 'blanket,cover', + 'pier,wharf,wharfage,dock', + 'crt screen', + 'shower', + ), + 'thing_classes': + ( + 'bed,beds', + 'windowpane,window,windows', + 'cabinet,cabinets,wall mounted cabine', + 'person,child,girl,boy,woman,man,people,children,girls,boys,women,men', + 'door,double door,doors', + 'table,tables,tablecloth', + 'curtain,drape,drapery,mantle,pall', + 'chair,chairs', + 'car,automobile,cars', + 'painting,picture,paintings,pictures,wallart,framed canvas', + 'sofa,couch,sofas,couches', + 'shelf,shelves', + 'mirror,mirrors', + 'armchair,armchairs', + 'seat,seats', + 'fence,fencing', + 'desk,desks', + 'wardrobe,closet,press,wardrobes,closets', + 'lamp,lamps', + 'bathtub,bathing tub,bath,tub', + 'railing,rail', + 'cushion,cushions', + 'box,boxes', + 'column,pillar', + 'signboard,sign,signboards,signs', + 'chest of drawers,chest,bureau,dresser', + 'counter', + 'sink', + 'fireplace,hearth,open fireplace', + 'refrigerator,icebox', + 'stairs,steps', + 'case,display case,showcase,vitrine', + 'pool table,billiard table,snooker table', + 'pillow,pillows', + 'screen door,shower door', + 'bookcase', + 'coffee table,cocktail table', + 'toilet,commode,crapper,potty', + 'flower,flowers', + 'book,books', + 'bench,benches', + 'countertop,counter top,worktop', + 'stove,kitchen stove,kitchen range,kitchen range,cooking stove', + 'palm tree,palm trees', + 'kitchen island', + 'computer,computing machine,computing device,data processor,electronic computer,information processing system', + 'swivel chair', + 'boat', + 'arcade machine,arcade machines', + 'bus,autobus,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehicle', + 'towel', + 'light bulb,lightbulb,bulb,incandescent lamp,electric light,electric-light bulb', + 'truck,motortruck', + 'chandelier,pendant,pendent', + 'awning,sunshade,sunblind', + 'streetlight,street lamp', + 'booth,cubicle,stall,kiosk', + 'television receiver,television,television set,tv,tv set', + 'airplane,aeroplane,airplanes,aeroplanes', + 'apparel,wearing apparel,dress,clothes', + 'pole', + 'bannister,banister,balustrade,balusters,handrail', + 'ottoman,pouf,pouffe,puff,hassock', + 'bottle,bottles,water bottle', + 'van', + 'ship', + 
'fountain', + 'washer,automatic washer,washing machine', + 'plaything,toy,toys', + 'stool,stools', + 'barrel,cask,barrels,casks', + 'basket,handbasket', + 'bag,bags,gift bag,paper bag', + 'minibike,motorbike', + 'oven', + 'ball,balls', + 'food,solid food', + 'step,stair', + 'trade name,brand name,brand,marque', + 'microwave,microwave oven', + 'plant pots,plant pot,flower pot,flowerpot,planter', + 'animal,animate being,dog,cat,horse,cow,sheep,zebra,girraffe,bird', + 'bicycle,bike', + 'dishwasher,dish washer,dishwashing machine', + 'projection screen', + 'sculpture,sculptures', + 'exhaust hood', + 'sconce,sconce lamp,sconce light', + 'vase,vases', + 'traffic light,traffic signal,traffic lights', + 'tray,trays', + 'ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin', + 'ceiling fan,floor fan', + 'plate,plates', + 'monitor,monitoring device,monitors', + 'bulletin board,notice board', + 'radiator', + 'cup,cups,drinking glass,drinking glasses', + 'clock', + 'flag,flags', + ), + 'stuff_classes': + ( + 'wall,walls,brick wall,stone wall,interior wall', + 'building,buildings,edifice,edifices', + 'sky,clouds', + 'floor,flooring', + 'tree,trees', + 'ceiling', + 'road,route,street,roads,streets,routes', + 'grass,grass field', + 'sidewalk,pavement', + 'earth,ground', + 'mountain,mount,mountains', + 'plant,flora,plant life,plants,bushes', + 'water', + 'house exterior', + 'sea,ocean', + 'rug,carpet,carpeting', + 'field', + 'rock,stone,rocks,stones', + 'pedestal', + 'sand', + 'skyscraper,skyscrapers', + 'grandstand,covered stand', + 'path', + 'runway', + 'stairway,staircase', + 'river', + 'bridge,span', + 'window screen,door screen', + 'hill', + 'bar', + 'hovel,hut,hutch,shack,shanty', + 'tower,towers', + 'dirt track', + 'land,soil', + 'escalator,moving staircase,moving stairway', + 'buffet,sideboard', + 'poster,posting,placard,notice,bill,card', + 'stage', + 'conveyer belt,conveyor belt,conveyer,conveyor,transporter', + 'canopy', + 'swimming pool,swimming bath', + 'waterfall,falls', + 'tent,collapsible shelter', + 'cradle', + 'tank,storage tank', + 'lake', + 'blanket,cover', + 'pier,wharf,wharfage,dock', + 'crt screen', + 'shower', + ), + } + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.coco = self.COCOAPI(local_path) + for idx, name in enumerate(self.metainfo['classes']): + if not (self.coco.cats[idx]['name'].strip() in name.split(',')): + print(f"Warning {idx} !!:\n{self.coco.cats[idx]['name']} vs {name}") + # use all classes, cannot use self.metainfo anymore. + self.cat_ids = self.coco.get_cat_ids() + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info({ + 'raw_ann_info': + raw_ann_info, + 'raw_img_info': + raw_img_info + }) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" 
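The class names in METAINFO above are comma-separated synonym lists, and the check in load_data_list only requires the original annotation category name to appear among the synonyms, printing a warning (rather than raising) when it does not. A small illustration of that membership test:

entry = 'person,child,girl,boy,woman,man,people,children,girls,boys,women,men'
assert 'person' in entry.split(',')     # passes the check silently
assert 'human' not in entry.split(',')  # this case would only print the warning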
+ + del self.coco + + return data_list diff --git a/seg/datasets/cityscapes.py b/seg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..b97a248e9abb61625ebbe73240e6fabd87f8f723 --- /dev/null +++ b/seg/datasets/cityscapes.py @@ -0,0 +1,103 @@ +from mmdet.registry import DATASETS +from mmdet.datasets.coco_panoptic import CocoPanopticDataset +import os.path as osp + +@DATASETS.register_module() +class CityscapesPanopticDataset(CocoPanopticDataset): + """Cityscapes dataset for Panoptic segmentation. + The class names are changed. + """ + + METAINFO = { + 'classes': + ( + 'person', 'rider', 'car', 'truck', 'bus', + 'train', 'motorcycle', 'bicycle', + + 'road', 'sidewalk', 'building', 'wall', 'fence', + 'pole', 'traffic light', 'traffic sign', 'vegetation', + 'terrain', 'sky' + ), + 'thing_classes': + ( + 'person', 'rider', 'car', 'truck', 'bus', + 'train', 'motorcycle', 'bicycle' + ), + 'stuff_classes': + ( + 'road', 'sidewalk', 'building', 'wall', 'fence', + 'pole', 'traffic light', 'traffic sign', 'vegetation', + 'terrain', 'sky' + ), + } + + def parse_data_info(self, raw_data_info: dict) -> dict: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information load from ``ann_file``. + + Returns: + dict: Parsed annotation. + """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + # filter out unmatched annotations which have + # same segment_id but belong to other image + ann_info = [ + ann for ann in ann_info if ann['image_id'] == img_info['img_id'] + ] + data_info = {} + + img_path = osp.join(self.data_prefix['img'], img_info['file_name']) + if self.data_prefix.get('seg', None): + seg_map_path = osp.join( + self.data_prefix['seg'], + img_info['file_name'].replace("_leftImg8bit.png", "_panoptic.png") + ) + else: + seg_map_path = None + data_info['img_path'] = img_path + data_info['img_id'] = img_info['img_id'] + data_info['seg_map_path'] = seg_map_path + data_info['height'] = img_info['height'] + data_info['width'] = img_info['width'] + + if self.return_classes: + data_info['text'] = self.metainfo['thing_classes'] + data_info['stuff_text'] = self.metainfo['stuff_classes'] + data_info['custom_entities'] = True # no important + + instances = [] + segments_info = [] + for ann in ann_info: + instance = {} + x1, y1, w, h = ann['bbox'] + if ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + category_id = ann['category_id'] + contiguous_cat_id = self.cat2label[category_id] + + is_thing = self.coco.load_cats(ids=category_id)[0]['isthing'] + if is_thing: + is_crowd = ann.get('iscrowd', False) + instance['bbox'] = bbox + instance['bbox_label'] = contiguous_cat_id + if not is_crowd: + instance['ignore_flag'] = 0 + else: + instance['ignore_flag'] = 1 + is_thing = False + + segment_info = { + 'id': ann['id'], + 'category': contiguous_cat_id, + 'is_thing': is_thing + } + segments_info.append(segment_info) + if len(instance) > 0 and is_thing: + instances.append(instance) + data_info['instances'] = instances + data_info['segments_info'] = segments_info + return data_info diff --git a/seg/datasets/coco_ins_ov.py b/seg/datasets/coco_ins_ov.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b8409ba3a6d1cdfb9ce929e2b1dd0aaa989d31 --- /dev/null +++ b/seg/datasets/coco_ins_ov.py @@ -0,0 +1,255 @@ +import copy +from typing import List + +from mmdet.registry import DATASETS +from mmdet.datasets.coco import CocoDataset +from mmengine import 
get_local_path, print_log + +CLASSES_ORIGINAL = ( + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', + 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', + 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', + 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', + 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', + 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush' +) + +CLASSES_48 = ( + 'person', 'bicycle', 'car', 'motorcycle', 'truck', 'boat', 'bench', + 'bird', 'horse', 'sheep', 'zebra', 'giraffe', 'backpack', + 'handbag', 'skis', 'kite', 'surfboard', 'bottle', 'spoon', + 'bowl', 'banana', 'apple', 'orange', 'broccoli', 'carrot', + 'pizza', 'donut', 'chair', 'bed', 'tv', 'laptop', + 'remote', 'microwave', 'oven', 'refrigerator', 'book', + 'clock', 'vase', 'toothbrush', 'train', 'bear', 'suitcase', + 'frisbee', 'fork', 'sandwich', 'toilet', 'mouse', 'toaster' +) + +CLASSES_17 = ( + 'bus', 'dog', 'cow', 'elephant', 'umbrella', 'tie', + 'skateboard', 'cup', 'knife', 'cake', + 'couch', 'keyboard', 'sink', 'scissors', + 'airplane', 'cat', 'snowboard' +) + +CLASSES_IDS_48 = [0, 1, 2, 3, 7, 8, 13, 14, 17, 18, 22, 23, 24, 26, 30, 33, 37, 39, 44, 45, 46, 47, 49, 50, 51, 53, 54, + 56, 59, 62, 63, 65, 68, 69, 72, 73, 74, 75, 79, 6, 21, 28, 29, 42, 48, 61, 64, 70] +CLASSES_IDS_17 = [5, 16, 19, 20, 25, 27, 36, 41, 43, 55, 57, 66, 71, 76, 4, 15, 31] + + +@DATASETS.register_module() +class CocoOVDataset(CocoDataset): + """Coco Open Vocabulary dataset for Instance segmentation. + The class names are changed. 
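CLASSES_IDS_48 and CLASSES_IDS_17 are intended to be the positions of the 48 base and 17 novel categories inside CLASSES_ORIGINAL. A quick sanity-check sketch, assuming the repository and its mmdet/mmengine dependencies are importable so that seg.datasets.coco_ins_ov resolves:

# Sanity-check sketch: the hard-coded id lists should equal the indices of the
# 48/17 class names inside CLASSES_ORIGINAL. Import path is an assumption.
from seg.datasets.coco_ins_ov import (CLASSES_ORIGINAL, CLASSES_48, CLASSES_17,
                                      CLASSES_IDS_48, CLASSES_IDS_17)

assert CLASSES_IDS_48 == [CLASSES_ORIGINAL.index(c) for c in CLASSES_48]
assert CLASSES_IDS_17 == [CLASSES_ORIGINAL.index(c) for c in CLASSES_17]
assert set(CLASSES_48).isdisjoint(CLASSES_17)
print('48 base + 17 novel =', len(CLASSES_48) + len(CLASSES_17), 'categories')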
+ """ + METAINFO = { + 'classes': + ('person,child,girl,boy,woman,man,people,children,girls,boys,women,men,lady,guy,ladies,guys,clothes', + 'bicycle,bicycles,bike,bikes', + 'car,cars', + 'motorcycle,motorcycles', + 'airplane,airplanes', + 'bus,buses', + 'train,trains,locomotive,locomotives,freight train', + 'truck,trucks', + 'boat,boats', + 'traffic light', + 'fire hydrant', + 'stop sign', + 'parking meter', + 'bench,benches', + 'bird,birds', + 'cat,cats,kitties,kitty', + 'dog,dogs,puppy,puppies', + 'horse,horses,foal', + 'sheep', + 'cow,cows,calf', + 'elephant,elephants', + 'bear,bears', + 'zebra,zebras', + 'giraffe,giraffes', + 'backpack,backpacks', + 'umbrella,umbrellas', + 'handbag,handbags', + 'tie', + 'suitcase,suitcases', + 'frisbee', + 'skis', + 'snowboard', + 'sports ball', + 'kite,kites', + 'baseball bat', + 'baseball glove', + 'skateboard', + 'surfboard', + 'tennis racket', + 'bottle,bottles,water bottle', + 'wine glass,wine glasses,wineglass', + 'cup,cups,water cup,water glass', + 'fork,forks', + 'knife,knives', + 'spoon,spoons', + 'bowl,bowls', + 'banana,bananas', + 'apple,apples,apple fruit', + 'sandwich,sandwiches', + 'orange fruit', + 'broccoli', + 'carrot,carrots', + 'hot dog', + 'pizza', + 'donut,donuts', + 'cake,cakes', + 'chair,chairs', + 'couch,sofa,sofas', + 'potted plant,potted plants,pottedplant,pottedplants,planter,planters', + 'bed,beds', + 'dining table,dining tables,diningtable,diningtables,plate,plates,diningtable tablecloth', + 'toilet', + 'tv', + 'laptop', + 'mouse', + 'tv remote,remote control', + 'keyboard', + 'cell phone,mobile', + 'microwave', + 'oven,ovens', + 'toaster', + 'sink,sinks', + 'refrigerator,fridge', + 'book,books', + 'clock', + 'vase,vases', + 'scissor,scissors', + 'teddy bear,teddy bears', + 'hair drier', + 'toothbrush,toothbrushes', + ), + + 'palette': + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208)] + } + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. 
+ """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.coco = self.COCOAPI(local_path) + # The order of returned `cat_ids` will not + # change with the order of the `classes` + self.cat_ids = self.coco.get_cat_ids( + cat_names=CLASSES_ORIGINAL) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info({ + 'raw_ann_info': + raw_ann_info, + 'raw_img_info': + raw_img_info + }) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list + + def filter_data(self) -> List[dict]: + valid_data_infos = super().filter_data() + + if self.filter_cfg is None: + return valid_data_infos + + sub_split = self.filter_cfg.get('sub_split', None) + if sub_split is None: + return valid_data_infos + + if sub_split == '48_17': + with_cat_ids = [] + wo_cat_ids = [] + classes = list(CLASSES_ORIGINAL) + if self.test_mode: + for cls in CLASSES_17: + with_cat_ids.append(classes.index(cls)) + for cls in CLASSES_48: + with_cat_ids.append(classes.index(cls)) + else: + for cls in CLASSES_48: + with_cat_ids.append(classes.index(cls)) + for cls in CLASSES_17: + wo_cat_ids.append(classes.index(cls)) + else: + raise ValueError(f"{sub_split} does not support") + + keep_w_novel = True + filtered_data_infos = [] + for data_info in valid_data_infos: + instances = data_info['instances'] + filtered_instances = [] + flag = False + for ins in instances: + if ins['bbox_label'] in with_cat_ids: + filtered_instances.append(ins) + flag = True + if not flag: + continue + if not keep_w_novel: + for ins in instances: + if ins['bbox_label'] in wo_cat_ids: + filtered_instances.append(ins) + flag = False + break + if flag: + data_info['instances'] = filtered_instances + filtered_data_infos.append(data_info) + + print_log( + f"There are totally {len(filtered_data_infos)} images in the filtered dataset.", + logger='current', + ) + return filtered_data_infos diff --git a/seg/datasets/coco_ov.py b/seg/datasets/coco_ov.py new file mode 100644 index 0000000000000000000000000000000000000000..e7120cf578fef6bf50a554f47f344ace258e4146 --- /dev/null +++ b/seg/datasets/coco_ov.py @@ -0,0 +1,363 @@ +import copy +from typing import List + +from mmdet.registry import DATASETS +from mmdet.datasets.coco_panoptic import CocoPanopticDataset +from mmengine import get_local_path + + +@DATASETS.register_module() +class CocoPanopticOVDataset(CocoPanopticDataset): + """Coco Open Vocabulary dataset for Panoptic segmentation. + The class names are changed. 
+ """ + + METAINFO = { + 'classes': + ('person,child,girl,boy,woman,man,people,children,girls,boys,women,men,lady,guy,ladies,guys,clothes', + 'bicycle,bicycles,bike,bikes', + 'car,cars', + 'motorcycle,motorcycles', + 'airplane,airplanes', + 'bus,buses', + 'train,trains,locomotive,locomotives,freight train', + 'truck,trucks', + 'boat,boats', + 'traffic light', + 'fire hydrant', + 'stop sign', + 'parking meter', + 'bench,benches', + 'bird,birds', + 'cat,cats,kitties,kitty', + 'dog,dogs,puppy,puppies', + 'horse,horses,foal', + 'sheep', + 'cow,cows,calf', + 'elephant,elephants', + 'bear,bears', + 'zebra,zebras', + 'giraffe,giraffes', + 'backpack,backpacks', + 'umbrella,umbrellas', + 'handbag,handbags', + 'tie', + 'suitcase,suitcases', + 'frisbee', + 'skis', + 'snowboard', + 'sports ball', + 'kite,kites', + 'baseball bat', + 'baseball glove', + 'skateboard', + 'surfboard', + 'tennis racket', + 'bottle,bottles,water bottle', + 'wine glass,wine glasses,wineglass', + 'cup,cups,water cup,water glass', + 'fork,forks', + 'knife,knives', + 'spoon,spoons', + 'bowl,bowls', + 'banana,bananas', + 'apple,apples,apple fruit', + 'sandwich,sandwiches', + 'orange fruit', + 'broccoli', + 'carrot,carrots', + 'hot dog', + 'pizza', + 'donut,donuts', + 'cake,cakes', + 'chair,chairs', + 'couch,sofa,sofas', + 'potted plant,potted plants,pottedplant,pottedplants,planter,planters', + 'bed,beds', + 'dining table,dining tables,diningtable,diningtables,plate,plates,diningtable tablecloth', + 'toilet', + 'tv', + 'laptop', + 'mouse', + 'tv remote,remote control', + 'keyboard', + 'cell phone,mobile', + 'microwave', + 'oven,ovens', + 'toaster', + 'sink,sinks', + 'refrigerator,fridge', + 'book,books', + 'clock', + 'vase,vases', + 'scissor,scissors', + 'teddy bear,teddy bears', + 'hair drier', + 'toothbrush,toothbrushes', + 'banner,banners', + 'blanket,blankets', + 'bridge', + 'cardboard', + 'counter', + 'curtain,curtains', + 'door,doors', + 'wood floor', + 'flower,flowers', + 'fruit,fruits', + 'gravel', + 'house', + 'lamp,bulb,lamps,bulbs', + 'mirror', + 'tennis net', + 'pillow,pillows', + 'platform', + 'playingfield,tennis court,baseball field,soccer field,tennis field', + 'railroad', + 'river', + 'road', + 'roof', + 'sand', + 'sea,sea wave,wave,waves', + 'shelf', + 'snow', + 'stairs', + 'tent', + 'towel', + 'brick wall', + 'stone wall', + 'tile wall', + 'wood wall', + 'water', + 'window blind', + 'window', + 'tree,trees,palm tree,bushes', + 'fence,fences', + 'ceiling', + 'sky,clouds', + 'cabinet,cabinets', + 'table', + 'floor,flooring,tile floor', + 'pavement', + 'mountain,mountains', + 'grass', + 'dirt', + 'paper', + 'food', + 'building,buildings', + 'rock', + 'wall,walls', + 'rug', + ), + 'thing_classes': + ('person,child,girl,boy,woman,man,people,children,girls,boys,women,men,lady,guy,ladies,guys,clothes', + 'bicycle,bicycles,bike,bikes', + 'car,cars', + 'motorcycle,motorcycles', + 'airplane,airplanes', + 'bus,buses', + 'train,trains,locomotive,locomotives,freight train', + 'truck,trucks', + 'boat,boats', + 'traffic light', + 'fire hydrant', + 'stop sign', + 'parking meter', + 'bench,benches', + 'bird,birds', + 'cat,cats,kitties,kitty', + 'dog,dogs,puppy,puppies', + 'horse,horses,foal', + 'sheep', + 'cow,cows,calf', + 'elephant,elephants', + 'bear,bears', + 'zebra,zebras', + 'giraffe,giraffes', + 'backpack,backpacks', + 'umbrella,umbrellas', + 'handbag,handbags', + 'tie', + 'suitcase,suitcases', + 'frisbee', + 'skis', + 'snowboard', + 'sports ball', + 'kite,kites', + 'baseball bat', + 'baseball glove', + 'skateboard', + 
'surfboard', + 'tennis racket', + 'bottle,bottles,water bottle', + 'wine glass,wine glasses,wineglass', + 'cup,cups,water cup,water glass', + 'fork,forks', + 'knife,knives', + 'spoon,spoons', + 'bowl,bowls', + 'banana,bananas', + 'apple,apples,apple fruit', + 'sandwich,sandwiches', + 'orange fruit', + 'broccoli', + 'carrot,carrots', + 'hot dog', + 'pizza', + 'donut,donuts', + 'cake,cakes', + 'chair,chairs', + 'couch,sofa,sofas', + 'potted plant,potted plants,pottedplant,pottedplants,planter,planters', + 'bed,beds', + 'dining table,dining tables,diningtable,diningtables,plate,plates,diningtable tablecloth', + 'toilet', + 'tv', + 'laptop', + 'mouse', + 'tv remote,remote control', + 'keyboard', + 'cell phone,mobile', + 'microwave', + 'oven,ovens', + 'toaster', + 'sink,sinks', + 'refrigerator,fridge', + 'book,books', + 'clock', + 'vase,vases', + 'scissor,scissors', + 'teddy bear,teddy bears', + 'hair drier', + 'toothbrush,toothbrushes', + ), + 'stuff_classes': + ('banner,banners', + 'blanket,blankets', + 'bridge', + 'cardboard', + 'counter', + 'curtain,curtains', + 'door,doors', + 'wood floor', + 'flower,flowers', + 'fruit,fruits', + 'gravel', + 'house', + 'lamp,bulb,lamps,bulbs', + 'mirror', + 'tennis net', + 'pillow,pillows', + 'platform', + 'playingfield,tennis court,baseball field,soccer field,tennis field', + 'railroad', + 'river', + 'road', + 'roof', + 'sand', + 'sea,sea wave,wave,waves', + 'shelf', + 'snow', + 'stairs', + 'tent', + 'towel', + 'brick wall', + 'stone wall', + 'tile wall', + 'wood wall', + 'water', + 'window blind', + 'window', + 'tree,trees,palm tree,bushes', + 'fence,fences', + 'ceiling', + 'sky,clouds', + 'cabinet,cabinets', + 'table', + 'floor,flooring,tile floor', + 'pavement', + 'mountain,mountains', + 'grass', + 'dirt', + 'paper', + 'food', + 'building,buildings', + 'rock', + 'wall,walls', + 'rug' + ), + 'palette': + [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), + (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), + (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), + (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), + (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), + (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), + (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), + (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), + (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), + (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), + (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), + (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), + (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), + (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), + (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), + (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), + (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), + (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), + (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), + (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203), + (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100), + (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255), + (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0), + (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176), + (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55), + (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 
255, 255), + (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74), + (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149), + (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153), + (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140), + (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152), + (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0), + (0, 114, 143), (102, 102, 156), (250, 141, 255)] + } + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file`` + + Returns: + List[dict]: A list of annotation. + """ # noqa: E501 + with get_local_path( + self.ann_file, backend_args=self.backend_args) as local_path: + self.coco = self.COCOAPI(local_path) + # use all classes, cannot use self.metainfo anymore. + self.cat_ids = self.coco.get_cat_ids() + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) + + img_ids = self.coco.get_img_ids() + data_list = [] + total_ann_ids = [] + for img_id in img_ids: + raw_img_info = self.coco.load_imgs([img_id])[0] + raw_img_info['img_id'] = img_id + + ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) + raw_ann_info = self.coco.load_anns(ann_ids) + total_ann_ids.extend(ann_ids) + + parsed_data_info = self.parse_data_info({ + 'raw_ann_info': + raw_ann_info, + 'raw_img_info': + raw_img_info + }) + data_list.append(parsed_data_info) + if self.ANN_ID_UNIQUE: + assert len(set(total_ann_ids)) == len( + total_ann_ids + ), f"Annotation ids in '{self.ann_file}' are not unique!" + + del self.coco + + return data_list diff --git a/seg/datasets/coco_pan_sam.py b/seg/datasets/coco_pan_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..02fba828788eb623ef42cd2e2babc298574ea1c8 --- /dev/null +++ b/seg/datasets/coco_pan_sam.py @@ -0,0 +1,82 @@ +import os.path as osp + +from mmdet.registry import DATASETS +from mmdet.datasets.coco_panoptic import CocoPanopticDataset + + +@DATASETS.register_module() +class CocoPanopticSAMDataset(CocoPanopticDataset): + """Coco SAM dataset, stuff is treated as thing. + """ + + def parse_data_info(self, raw_data_info: dict) -> dict: + """Parse raw annotation to target format. + + Args: + raw_data_info (dict): Raw data information load from ``ann_file``. + + Returns: + dict: Parsed annotation. 
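CocoPanopticSAMDataset introduced here treats every panoptic segment as a thing; its parse_data_info (continued below) converts each COCO xywh box to xyxy and flags crowd regions as ignored. A minimal sketch of that per-annotation conversion on a toy record:

# Minimal sketch of the "everything is a thing" parsing: COCO xywh -> xyxy,
# degenerate boxes skipped, crowd annotations marked as ignored. Toy data only.
def parse_segment(ann, cat2label):
    x1, y1, w, h = ann['bbox']
    if ann['area'] <= 0 or w < 1 or h < 1:
        return None                                     # skip degenerate boxes
    return {
        'bbox': [x1, y1, x1 + w, y1 + h],
        'bbox_label': cat2label[ann['category_id']],
        'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
    }

toy_ann = {'bbox': [10, 20, 30, 40], 'area': 1200, 'category_id': 7}
print(parse_segment(toy_ann, cat2label={7: 0}))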
+ """ + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + # filter out unmatched annotations which have + # same segment_id but belong to other image + ann_info = [ + ann for ann in ann_info if ann['image_id'] == img_info['img_id'] + ] + data_info = {} + + img_path = osp.join(self.data_prefix['img'], img_info['file_name']) + if self.data_prefix.get('seg', None): + seg_map_path = osp.join( + self.data_prefix['seg'], + img_info['file_name'].replace('jpg', 'png')) + else: + seg_map_path = None + data_info['img_path'] = img_path + data_info['img_id'] = img_info['img_id'] + data_info['seg_map_path'] = seg_map_path + data_info['height'] = img_info['height'] + data_info['width'] = img_info['width'] + + if self.return_classes: + data_info['text'] = self.metainfo['thing_classes'] + data_info['stuff_text'] = self.metainfo['stuff_classes'] + data_info['custom_entities'] = True # no important + + instances = [] + segments_info = [] + for ann in ann_info: + instance = {} + x1, y1, w, h = ann['bbox'] + if ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + category_id = ann['category_id'] + contiguous_cat_id = self.cat2label[category_id] + + _ = self.coco.load_cats(ids=category_id)[0]['isthing'] + # always set is_thing to true + is_thing = True + if is_thing: + is_crowd = ann.get('iscrowd', False) + instance['bbox'] = bbox + instance['bbox_label'] = contiguous_cat_id + if not is_crowd: + instance['ignore_flag'] = 0 + else: + instance['ignore_flag'] = 1 + is_thing = False + + segment_info = { + 'id': ann['id'], + 'category': contiguous_cat_id, + 'is_thing': is_thing + } + segments_info.append(segment_info) + if len(instance) > 0 and is_thing: + instances.append(instance) + data_info['instances'] = instances + data_info['segments_info'] = segments_info + return data_info diff --git a/seg/datasets/concat_dataset.py b/seg/datasets/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..954bb4cae870eee43e71463e4fcd1c185f45aaea --- /dev/null +++ b/seg/datasets/concat_dataset.py @@ -0,0 +1,190 @@ +from abc import ABC +import logging +from typing import Sequence, Union, Optional, Tuple + +from mmengine.dataset import ConcatDataset, RepeatDataset, ClassBalancedDataset +from mmengine.logging import print_log +from mmengine.registry import DATASETS +from mmengine.dataset.base_dataset import BaseDataset + +from mmdet.structures import TrackDataSample + +from seg.models.utils import NO_OBJ + + +@DATASETS.register_module() +class ConcatOVDataset(ConcatDataset, ABC): + _fully_initialized: bool = False + + def __init__(self, + datasets: Sequence[Union[BaseDataset, dict]], + lazy_init: bool = False, + data_tag: Optional[Tuple[str]] = None, + ): + for i, dataset in enumerate(datasets): + if isinstance(dataset, dict): + dataset.update(lazy_init=lazy_init) + if 'times' in dataset: + dataset['dataset'].update(lazy_init=lazy_init) + super().__init__(datasets, lazy_init=lazy_init, + ignore_keys=['classes', 'thing_classes', 'stuff_classes', 'palette']) + self.data_tag = data_tag + if self.data_tag is not None: + assert len(self.data_tag) == len(datasets) + + cls_names = [] + for dataset in self.datasets: + if isinstance(dataset, RepeatDataset) or isinstance(dataset, ClassBalancedDataset): + if hasattr(dataset.dataset, 'dataset_name'): + name = dataset.dataset.dataset_name + else: + name = dataset.dataset.__class__.__name__ + else: + if hasattr(dataset, 'dataset_name'): + name = dataset.dataset_name + else: + name = 
dataset.__class__.__name__ + cls_names.append(name) + + thing_classes = [] + thing_mapper = [] + stuff_classes = [] + stuff_mapper = [] + for idx, dataset in enumerate(self.datasets): + if 'classes' not in dataset.metainfo or (self.data_tag is not None and self.data_tag[idx] in ['sam']): + # class agnostic dataset + _thing_mapper = {} + _stuff_mapper = {} + thing_mapper.append(_thing_mapper) + stuff_mapper.append(_stuff_mapper) + continue + _thing_classes = dataset.metainfo['thing_classes'] \ + if 'thing_classes' in dataset.metainfo else dataset.metainfo['classes'] + _stuff_classes = dataset.metainfo['stuff_classes'] if 'stuff_classes' in dataset.metainfo else [] + _thing_mapper = {} + _stuff_mapper = {} + for idy, cls in enumerate(_thing_classes): + flag = False + cls = cls.replace('_or_', ',') + cls = cls.replace('/', ',') + cls = cls.replace('_', ' ') + cls = cls.lower() + for all_idx, all_cls in enumerate(thing_classes): + if set(cls.split(',')).intersection(set(all_cls.split(','))): + _thing_mapper[idy] = all_idx + flag = True + break + if not flag: + thing_classes.append(cls) + _thing_mapper[idy] = len(thing_classes) - 1 + thing_mapper.append(_thing_mapper) + + for idy, cls in enumerate(_stuff_classes): + flag = False + cls = cls.replace('_or_', ',') + cls = cls.replace('/', ',') + cls = cls.replace('_', ' ') + cls = cls.lower() + for all_idx, all_cls in enumerate(stuff_classes): + if set(cls.split(',')).intersection(set(all_cls.split(','))): + _stuff_mapper[idy] = all_idx + flag = True + break + if not flag: + stuff_classes.append(cls) + _stuff_mapper[idy] = len(stuff_classes) - 1 + stuff_mapper.append(_stuff_mapper) + + cls_name = "" + cnt = 0 + dataset_idx = 0 + classes = [*thing_classes, *stuff_classes] + mapper = [] + meta_cls_names = [] + for _thing_mapper, _stuff_mapper in zip(thing_mapper, stuff_mapper): + if not _thing_mapper and not _stuff_mapper: + # class agnostic dataset + _mapper = dict() + for idx in range(1000): + _mapper[idx] = -1 + else: + _mapper = {**_thing_mapper} + _num_thing = len(_thing_mapper) + for key, value in _stuff_mapper.items(): + assert value < len(stuff_classes) + _mapper[key + _num_thing] = _stuff_mapper[key] + len(thing_classes) + assert len(_mapper) == len(_thing_mapper) + len(_stuff_mapper) + cnt += 1 + cls_name = cls_name + cls_names[dataset_idx] + "_" + meta_cls_names.append(cls_names[dataset_idx]) + _mapper[NO_OBJ] = NO_OBJ + mapper.append(_mapper) + dataset_idx += 1 + if cnt > 1: + cls_name = "Concat_" + cls_name + cls_name = cls_name[:-1] + self.dataset_name = cls_name + + self._metainfo.update({ + 'classes': classes, + 'thing_classes': thing_classes, + 'stuff_classes': stuff_classes, + 'mapper': mapper, + 'dataset_names': meta_cls_names + }) + print_log( + f"------------{self.dataset_name}------------", + logger='current', + level=logging.INFO + ) + + for idx, dataset in enumerate(self.datasets): + dataset_type = cls_names[idx] + if isinstance(dataset, RepeatDataset): + times = dataset.times + else: + times = 1 + print_log( + f"|---dataset#{idx + 1} --> name: {dataset_type}; length: {len(dataset)}; repeat times: {times}", + logger='current', + level=logging.INFO + ) + + print_log( + f"------num_things : {len(thing_classes)}; num_stuff : {len(stuff_classes)}------", + logger='current', + level=logging.INFO + ) + + def get_dataset_source(self, idx: int) -> int: + dataset_idx, _ = self._get_ori_dataset_idx(idx) + return dataset_idx + + def __getitem__(self, idx): + if not self._fully_initialized: + print_log( + 'Please call `full_init` method 
manually to ' + 'accelerate the speed.', + logger='current', + level=logging.WARNING) + self.full_init() + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) + results = self.datasets[dataset_idx][sample_idx] + _mapper = self.metainfo['mapper'][dataset_idx] + + data_samples = results['data_samples'] + if isinstance(data_samples, TrackDataSample): + for det_sample in data_samples: + if 'gt_sem_seg' in det_sample: + det_sample.gt_sem_seg.sem_seg.apply_(lambda x: _mapper.__getitem__(x)) + if 'gt_instances' in det_sample: + det_sample.gt_instances.labels.apply_(lambda x: _mapper.__getitem__(x)) + else: + if 'gt_sem_seg' in data_samples: + data_samples.gt_sem_seg.sem_seg.apply_(lambda x: _mapper.__getitem__(x)) + if 'gt_instances' in data_samples: + data_samples.gt_instances.labels.apply_(lambda x: _mapper.__getitem__(x)) + + if self.data_tag is not None: + data_samples.data_tag = self.data_tag[dataset_idx] + return results diff --git a/seg/datasets/davis.py b/seg/datasets/davis.py new file mode 100644 index 0000000000000000000000000000000000000000..0a71108b8add2bec3addb438bb8d5f60998a2618 --- /dev/null +++ b/seg/datasets/davis.py @@ -0,0 +1,195 @@ +import os +from typing import Tuple, List + +import pycocotools.mask as maskUtils + +import mmcv +import numpy as np +from mmdet.registry import DATASETS +from mmdet.datasets.base_video_dataset import BaseVideoDataset +from mmengine import fileio, join_path, scandir, track_parallel_progress, dump, list_from_file, print_log, exists, load +from mmengine.dist import master_only, dist + + +def mask2bbox(mask): + bbox = np.zeros((4,), dtype=np.float32) + x_any = np.any(mask, axis=0) + y_any = np.any(mask, axis=1) + x = np.where(x_any)[0] + y = np.where(y_any)[0] + if len(x) > 0 and len(y) > 0: + bbox = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32) + return bbox + + +def video_parser(params): + seq_id, vid_folder, ann_folder = params + images = [] + assert os.path.basename(vid_folder) == os.path.basename(ann_folder) + _tmp_img_id = -1 + imgs_cur = sorted(list(map( + lambda x: str(x), scandir(vid_folder, recursive=False, suffix='.jpg') + ))) + pans_cur = sorted(list(map( + lambda x: str(x), scandir(ann_folder, recursive=False, suffix='.png') + ))) + for img_cur, pan_cur in zip(imgs_cur, pans_cur): + assert img_cur.split('.')[0] == pan_cur.split('.')[0] + _tmp_img_id += 1 + img_id = _tmp_img_id + item_full = os.path.join(vid_folder, img_cur) + inst_map = os.path.join(ann_folder, pan_cur) + img_dict = { + 'img_path': item_full, + 'ann_path': inst_map, + } + assert os.path.exists(img_dict['img_path']) + assert os.path.exists(img_dict['ann_path']) + instances = [] + ann_map = mmcv.imread(img_dict['ann_path'], flag='unchanged').astype(np.uint32) + ann_map = ann_map[..., 0] * 1000000 + ann_map[..., 1] * 1000 + ann_map[..., 2] + img_dict['height'], img_dict['width'] = ann_map.shape + + for pan_seg_id in np.unique(ann_map): + if pan_seg_id == 0: + continue + instance = {} + mask = (ann_map == pan_seg_id).astype(np.uint8) + instance['instance_id'] = pan_seg_id + instance['bbox'] = mask2bbox(mask) + instance['bbox_label'] = 0 + instance['ignore_flag'] = 0 + instance['mask'] = maskUtils.encode(np.asfortranarray(mask)) + instance['mask']['counts'] = instance['mask']['counts'].decode() + instances.append(instance) + img_dict['instances'] = instances + img_dict['video_id'] = seq_id + img_dict['frame_id'] = img_id + img_dict['img_id'] = seq_id * 10000 + img_id + images.append(img_dict) + return { + 'video_id': seq_id, + 'images': images, + 'video_length': 
len(images) + } + + +@DATASETS.register_module() +class DAVIS(BaseVideoDataset): + METAINFO = { + 'classes': {}, + 'palette': {}, + } + + def __init__(self, dataset_version: str, *args, **kwargs): + self.__class__.__name__ = f'DVAIS_{dataset_version}' + super().__init__(*args, **kwargs) + + @master_only + def build_cache(self, ann_json_path, video_folders, ann_folders) -> None: + vid_ids = range(len(video_folders)) + + data_list = track_parallel_progress( + video_parser, + tasks=list(zip(vid_ids, video_folders, ann_folders)), + nproc=20, + keep_order=False, + ) + data_list = sorted(data_list, key=lambda x: x['video_id']) + dump(data_list, ann_json_path) + + def load_data_list(self) -> List[dict]: + """Load annotations from an annotation file named as ``self.ann_file``. + + Returns: + tuple(list[dict], list): A list of annotation and a list of + valid data indices. + """ + with fileio.get_local_path(self.ann_file) as local_path: + video_folders = list_from_file(local_path, prefix=self.data_prefix['img']) + ann_folders = list_from_file(local_path, prefix=self.data_prefix['ann']) + assert len(video_folders) == len(ann_folders) + print_log(f"#videos : {len(video_folders)} ", logger='current') + + split = os.path.basename(self.ann_file).split('.')[0] + ann_json_path = f"{split}_annotations.json" + ann_json_path = join_path(self.data_root, ann_json_path) + if not exists(ann_json_path): + self.build_cache(ann_json_path, video_folders, ann_folders) + dist.barrier() + raw_data_list = load(ann_json_path) + data_list = [] + for raw_data_info in raw_data_list: + data_info = self.parse_data_info(raw_data_info) + data_list.append(data_info) + vid_len_list = [itm['video_length'] for itm in data_list] + max_vid_len = max(vid_len_list) + min_vid_len = min(vid_len_list) + print_log( + f"Max video len : {max_vid_len}; " + f"Min video len : {min_vid_len}." + , + logger='current', + ) + return data_list + + def parse_data_info(self, raw_data_info: dict) -> dict: + data_info = { + 'video_id': raw_data_info['video_id'], + 'video_length': raw_data_info['video_length'] + } + images = [] + for raw_img_data_info in raw_data_info['images']: + img_data_info = { + 'img_path': raw_img_data_info['img_path'], + 'height': raw_img_data_info['height'], + 'width': raw_img_data_info['width'], + 'video_id': raw_img_data_info['video_id'], + 'frame_id': raw_img_data_info['frame_id'], + 'img_id': raw_img_data_info['img_id'] + } + instances = [] + segments_info = [] + for ann in raw_img_data_info['instances']: + instance = {} + category_id = ann['bbox_label'] + bbox = ann['bbox'] + is_thing = 1 + if is_thing: + instance['bbox'] = bbox + instance['bbox_label'] = category_id + instance['ignore_flag'] = ann['ignore_flag'] + instance['instance_id'] = ann['instance_id'] + + segment_info = { + 'mask': ann['mask'], + 'category': category_id, + 'is_thing': is_thing + } + segments_info.append(segment_info) + if len(instance) > 0 and is_thing: + instances.append(instance) + img_data_info['instances'] = instances + img_data_info['segments_info'] = segments_info + images.append(img_data_info) + data_info['images'] = images + return data_info + + def filter_data(self) -> List[dict]: + """Filter image annotations according to filter_cfg. + + Returns: + list[int]: Filtered results. 
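video_parser above packs an RGB annotation frame into a single integer id map (R*1e6 + G*1e3 + B) and derives one instance mask plus a tight bbox (mask2bbox) per non-zero id. A numpy-only sketch of that decoding on synthetic data:

# Numpy-only sketch of the DAVIS annotation decoding: an RGB frame becomes an
# integer id map, and each non-zero id yields a binary mask whose tight bbox
# is computed like `mask2bbox`. Synthetic 4x6 frame, one toy object with id 2.
import numpy as np

rgb = np.zeros((4, 6, 3), dtype=np.uint32)
rgb[1:3, 2:5] = (0, 0, 2)

id_map = rgb[..., 0] * 1000000 + rgb[..., 1] * 1000 + rgb[..., 2]

for obj_id in np.unique(id_map):
    if obj_id == 0:                            # 0 is background
        continue
    mask = (id_map == obj_id)
    xs = np.where(mask.any(axis=0))[0]
    ys = np.where(mask.any(axis=1))[0]
    bbox = (int(xs[0]), int(ys[0]), int(xs[-1]), int(ys[-1]))  # x1, y1, x2, y2
    print(int(obj_id), bbox)                   # -> 2 (2, 1, 4, 2)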
+ """ + if self.test_mode: + return self.data_list + + num_imgs_before_filter = sum([len(info['images']) for info in self.data_list]) + num_imgs_after_filter = num_imgs_before_filter + + new_data_list = self.data_list + + print_log( + 'The number of samples before and after filtering: ' + f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current') + return new_data_list diff --git a/seg/datasets/pipeliens/formatting.py b/seg/datasets/pipeliens/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..b06befcb8f420f147c05621d685893e683e2978b --- /dev/null +++ b/seg/datasets/pipeliens/formatting.py @@ -0,0 +1,246 @@ +from typing import Optional, Sequence, List + +import torch +import random +import numpy as np +from mmcv.transforms import to_tensor +from mmcv.transforms.base import BaseTransform +from mmdet.datasets.transforms import PackDetInputs +from mmdet.structures.bbox import BaseBoxes +from mmengine.structures import InstanceData, PixelData + +from mmdet.registry import TRANSFORMS +from mmdet.structures import DetDataSample, TrackDataSample + + +@TRANSFORMS.register_module() +class PackVidSegInputs(BaseTransform): + """Pack the inputs data for the multi object tracking and video instance + segmentation. All the information of images are packed to ``inputs``. All + the information except images are packed to ``data_samples``. In order to + get the original annotaiton and meta info, we add `instances` key into meta + keys. + + Args: + meta_keys (Sequence[str]): Meta keys to be collected in + ``data_sample.metainfo``. Defaults to None. + default_meta_keys (tuple): Default meta keys. Defaults to ('img_id', + 'img_path', 'ori_shape', 'img_shape', 'scale_factor', + 'flip', 'flip_direction', 'frame_id', 'is_video_data', + 'video_id', 'video_length', 'instances'). + """ + mapping_table = { + 'gt_bboxes': 'bboxes', + 'gt_bboxes_labels': 'labels', + 'gt_masks': 'masks', + 'gt_instances_ids': 'instances_ids' + } + + def __init__(self, + meta_keys: Optional[dict] = None, + default_meta_keys: tuple = ('img_id', 'img_path', 'ori_shape', + 'img_shape', 'scale_factor', + 'flip', 'flip_direction', + 'frame_id', 'video_id', + 'video_length', + 'ori_video_length', 'instances')): + self.meta_keys = default_meta_keys + if meta_keys is not None: + if isinstance(meta_keys, str): + meta_keys = (meta_keys,) + else: + assert isinstance(meta_keys, tuple), \ + 'meta_keys must be str or tuple' + self.meta_keys += meta_keys + + def transform(self, results: dict) -> dict: + """Method to pack the input data. + Args: + results (dict): Result dict from the data pipeline. + Returns: + dict: + - 'inputs' (dict[Tensor]): The forward data of models. + - 'data_samples' (obj:`TrackDataSample`): The annotation info of + the samples. + """ + packed_results = dict() + packed_results['inputs'] = dict() + + # 1. Pack images + if 'img' in results: + imgs = results['img'] + imgs = np.stack(imgs, axis=0) + imgs = imgs.transpose(0, 3, 1, 2) + packed_results['inputs'] = to_tensor(imgs) + + # 2. 
Pack InstanceData + if 'gt_ignore_flags' in results: + gt_ignore_flags_list = results['gt_ignore_flags'] + valid_idx_list, ignore_idx_list = [], [] + for gt_ignore_flags in gt_ignore_flags_list: + valid_idx = np.where(gt_ignore_flags == 0)[0] + ignore_idx = np.where(gt_ignore_flags == 1)[0] + valid_idx_list.append(valid_idx) + ignore_idx_list.append(ignore_idx) + + assert 'img_id' in results, "'img_id' must contained in the results " + 'for counting the number of images' + + num_imgs = len(results['img_id']) + instance_data_list = [InstanceData() for _ in range(num_imgs)] + ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)] + + for key in self.mapping_table.keys(): + if key not in results: + continue + if key == 'gt_masks' or (isinstance(results[key], List) and isinstance(results[key][0], BaseBoxes)): + mapped_key = self.mapping_table[key] + gt_masks_list = results[key] + if 'gt_ignore_flags' in results: + for i, gt_mask in enumerate(gt_masks_list): + valid_idx, ignore_idx = valid_idx_list[ + i], ignore_idx_list[i] + instance_data_list[i][mapped_key] = gt_mask[valid_idx] + ignore_instance_data_list[i][mapped_key] = gt_mask[ + ignore_idx] + + else: + for i, gt_mask in enumerate(gt_masks_list): + instance_data_list[i][mapped_key] = gt_mask + + else: + anns_list = results[key] + if 'gt_ignore_flags' in results: + for i, ann in enumerate(anns_list): + valid_idx, ignore_idx = valid_idx_list[ + i], ignore_idx_list[i] + instance_data_list[i][ + self.mapping_table[key]] = to_tensor( + ann[valid_idx]) + ignore_instance_data_list[i][ + self.mapping_table[key]] = to_tensor( + ann[ignore_idx]) + else: + for i, ann in enumerate(anns_list): + instance_data_list[i][ + self.mapping_table[key]] = to_tensor(ann) + + det_data_samples_list = [] + for i in range(num_imgs): + det_data_sample = DetDataSample() + det_data_sample.gt_instances = instance_data_list[i] + det_data_sample.ignored_instances = ignore_instance_data_list[i] + + if 'proposals' in results: + proposals = InstanceData( + bboxes=to_tensor(results['proposals'][i]), + scores=to_tensor(results['proposals_scores'][i])) + det_data_sample.proposals = proposals + + if 'gt_seg_map' in results: + gt_sem_seg_data = dict( + sem_seg=to_tensor(results['gt_seg_map'][i][None, ...].copy())) + gt_sem_seg_data = PixelData(**gt_sem_seg_data) + if 'ignore_index' in results: + metainfo = dict(ignore_index=results['ignore_index'][i]) + gt_sem_seg_data.set_metainfo(metainfo) + det_data_sample.gt_sem_seg = gt_sem_seg_data + + det_data_samples_list.append(det_data_sample) + + # 3. 
Pack metainfo + for key in self.meta_keys: + if key not in results: + continue + img_metas_list = results[key] + for i, img_meta in enumerate(img_metas_list): + det_data_samples_list[i].set_metainfo({f'{key}': img_meta}) + + track_data_sample = TrackDataSample() + track_data_sample.video_data_samples = det_data_samples_list + if 'key_frame_flags' in results: + key_frame_flags = np.asarray(results['key_frame_flags']) + key_frames_inds = np.where(key_frame_flags)[0].tolist() + ref_frames_inds = np.where(~key_frame_flags)[0].tolist() + track_data_sample.set_metainfo( + dict(key_frames_inds=key_frames_inds)) + track_data_sample.set_metainfo( + dict(ref_frames_inds=ref_frames_inds)) + + packed_results['data_samples'] = track_data_sample + return packed_results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'meta_keys={self.meta_keys}, ' + repr_str += f'default_meta_keys={self.default_meta_keys})' + return repr_str + + +@TRANSFORMS.register_module() +class PackSAMInputs(PackDetInputs): + mapping_table = { + 'gt_bboxes': 'bboxes', + 'gt_bboxes_labels': 'labels', + 'gt_masks': 'masks', + 'gt_point_coords': 'point_coords', + } + + def transform(self, results: dict) -> dict: + if 'feat' in results: + gt_feats = results['feat'] + results = super().transform(results) + results['data_samples'].gt_feats = gt_feats + return results + else: + return super().transform(results) + + +@TRANSFORMS.register_module() +class GeneratePoint(BaseTransform): + def __init__(self, num_proposals=60, num_mask_tokens=4): + self.num_proposals = num_proposals + self.num_mask_tokens = num_mask_tokens + + def transform(self, results): + data_samples = results['data_samples'] + gt_instances = data_samples.gt_instances + + ori_num_instances = len(gt_instances) + ori_indices = torch.randperm(ori_num_instances) + + if ori_num_instances < self.num_proposals: + repeat_cnt = (self.num_proposals // ori_num_instances) + 1 + ori_indices = ori_indices.repeat(repeat_cnt) + indices = ori_indices[:self.num_proposals] + + masks = gt_instances.masks.to_tensor(torch.bool, 'cpu') + gt_collected = [] + for instance_idx in indices: + mask = masks[instance_idx] + candidate_indices = mask.nonzero() + assert len(candidate_indices) > 0 + selected_index = random.randint(0, len(candidate_indices) - 1) + selected_point = candidate_indices[selected_index].flip(0) + + selected_instances_idx = [] + for instance_to_match_idx in range(len(gt_instances)): + mask_to_match = masks[instance_to_match_idx] + if mask_to_match[tuple(selected_point.flip(0))]: + selected_instances_idx.append(instance_to_match_idx) + assert len(selected_instances_idx) > 0 + if len(selected_instances_idx) > self.num_mask_tokens: + random.shuffle(selected_instances_idx) + selected_instances_idx = selected_instances_idx[:self.num_mask_tokens] + selected_point = torch.cat([selected_point - 3, selected_point + 3], 0) + gt_collected.append({ + 'point_coords': selected_point, + 'instances': selected_instances_idx, + }) + + data_samples.gt_instances_collected = InstanceData( + point_coords=torch.stack([itm['point_coords'] for itm in gt_collected]), + sub_instances=[itm['instances'] for itm in gt_collected], + idx=indices + ) + return results diff --git a/seg/datasets/pipeliens/frame_copy.py b/seg/datasets/pipeliens/frame_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca8bb9d03f6f59dbacd5b300c0facc676bb4908 --- /dev/null +++ b/seg/datasets/pipeliens/frame_copy.py @@ -0,0 +1,57 @@ +import copy + +import numpy as np +from mmcv 
import BaseTransform +from mmdet.registry import TRANSFORMS + +from seg.models.utils import NO_OBJ + + +@TRANSFORMS.register_module() +class ImageCopy(BaseTransform): + """Copy an image several times to build a video seq. + """ + DIVISOR = 10000 + + def __init__( + self, + num_frames: int = 1, + ) -> None: + assert num_frames > 1 + self.num_frames = num_frames + + def transform(self, results: dict) -> dict: + for key in results: + value = results[key] + results[key] = [] + for _ in range(self.num_frames): + results[key].append(copy.deepcopy(value)) + + num_instances = len(results['gt_bboxes_labels'][0]) + num_frames = len(results['gt_bboxes_labels']) + gt_instance_ids = results['gt_bboxes_labels'][0] * self.DIVISOR + np.arange(num_instances) + 1 + results['gt_instances_ids'] = [copy.deepcopy(gt_instance_ids) for _ in range(num_frames)] + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(num_frames={self.num_frames})' + return repr_str + + +@TRANSFORMS.register_module() +class AddSemSeg(BaseTransform): + """Add dummy semantic segmentation map. + """ + + def __init__(self, ) -> None: + pass + + def transform(self, results: dict) -> dict: + gt_seg = np.zeros(results['img'].shape[:2], dtype=np.int32) + NO_OBJ + results['gt_seg_map'] = gt_seg + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + return repr_str diff --git a/seg/datasets/pipeliens/frame_sampling.py b/seg/datasets/pipeliens/frame_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad111ce860715639cbfefc6c70163ea66d69b0f --- /dev/null +++ b/seg/datasets/pipeliens/frame_sampling.py @@ -0,0 +1,45 @@ +import random +from typing import Dict, List, Optional + +import numpy as np +from mmdet.registry import TRANSFORMS +from mmdet.datasets.transforms import BaseFrameSample + + +@TRANSFORMS.register_module() +class VideoClipSample(BaseFrameSample): + def __init__(self, + num_selected: int = 1, + interval: int = 1, + collect_video_keys: List[str] = ['video_id', 'video_length']): + self.num_selected = num_selected + self.interval = interval + super().__init__(collect_video_keys=collect_video_keys) + + def transform(self, video_infos: dict) -> Optional[Dict[str, List]]: + """Transform the video information. + + Args: + video_infos (dict): The whole video information. + + Returns: + dict: The data information of the sampled frames. 
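VideoClipSample (its transform body continues below) draws num_selected frames spaced interval apart from a random start, provided the clip still fits inside the video. A standalone sketch of that index arithmetic:

# Standalone sketch of the clip-sampling index arithmetic: pick a random start
# so that `num_selected` frames with the given stride fit inside the video.
import random
import numpy as np

def sample_clip(video_length, num_selected=3, interval=2):
    span = num_selected + (num_selected - 1) * (interval - 1)
    if span > video_length:
        return None                                  # clip does not fit
    first = random.randrange(video_length - span + 1)
    return first + np.arange(num_selected) * interval

print(sample_clip(video_length=10))   # e.g. [4 6 8]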
+ """ + len_with_interval = self.num_selected + (self.num_selected - 1) * (self.interval - 1) + len_video = video_infos['video_length'] + if len_with_interval > len_video: + return None + + first_frame_id = random.sample(range(len_video - len_with_interval + 1), 1)[0] + + sampled_frames_ids = first_frame_id + np.arange(self.num_selected) * self.interval + results = self.prepare_data(video_infos, sampled_frames_ids) + + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'num_selected=({self.num_selected}' + repr_str += f'interval={self.interval}' + repr_str += f'collect_video_keys={self.collect_video_keys})' + return repr_str diff --git a/seg/datasets/pipeliens/loading.py b/seg/datasets/pipeliens/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..31439152bd09f09bd8cbd2f8c5a491cd43211ca5 --- /dev/null +++ b/seg/datasets/pipeliens/loading.py @@ -0,0 +1,388 @@ +from typing import Optional, Tuple, Union + +import mmcv +import mmengine +import numpy as np +import pycocotools.mask as maskUtils +import torch + +from mmcv.transforms.base import BaseTransform +from mmdet.registry import TRANSFORMS +from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations +from mmdet.structures.bbox import autocast_box_type +from mmdet.structures.mask import BitmapMasks +from mmdet.datasets.transforms import LoadPanopticAnnotations +from mmengine.fileio import get + +from seg.models.utils import NO_OBJ + + +@TRANSFORMS.register_module() +class LoadPanopticAnnotationsHB(LoadPanopticAnnotations): + def _load_masks_and_semantic_segs(self, results: dict) -> None: + """Private function to load mask and semantic segmentation annotations. + + In gt_semantic_seg, the foreground label is from ``0`` to + ``num_things - 1``, the background label is from ``num_things`` to + ``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``). + + Args: + results (dict): Result dict from :obj:``mmdet.CustomDataset``. + """ + # seg_map_path is None, when inference on the dataset without gts. + if results.get('seg_map_path', None) is None: + return + + img_bytes = get( + results['seg_map_path'], backend_args=self.backend_args) + pan_png = mmcv.imfrombytes( + img_bytes, flag='color', channel_order='rgb').squeeze() + pan_png = self.rgb2id(pan_png) + + gt_masks = [] + gt_seg = np.zeros_like(pan_png).astype(np.int32) + NO_OBJ # 255 as ignore + + for segment_info in results['segments_info']: + mask = (pan_png == segment_info['id']) + gt_seg = np.where(mask, segment_info['category'], gt_seg) + + # The legal thing masks + if segment_info.get('is_thing'): + gt_masks.append(mask.astype(np.uint8)) + + if self.with_mask: + h, w = results['ori_shape'] + gt_masks = BitmapMasks(gt_masks, h, w) + results['gt_masks'] = gt_masks + + if self.with_seg: + results['gt_seg_map'] = gt_seg + + +@TRANSFORMS.register_module() +class LoadVideoSegAnnotations(LoadPanopticAnnotations): + + def __init__( + self, + **kwargs + ) -> None: + super().__init__(**kwargs) + + def _load_instances_ids(self, results: dict) -> None: + """Private function to load instances id annotations. + + Args: + results (dict): Result dict from :obj :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict containing instances id annotations. 
+ """ + gt_instances_ids = [] + for instance in results['instances']: + gt_instances_ids.append(instance['instance_id']) + results['gt_instances_ids'] = np.array( + gt_instances_ids, dtype=np.int32) + + def _load_masks_and_semantic_segs(self, results: dict) -> None: + h, w = results['ori_shape'] + gt_masks = [] + gt_seg = np.zeros((h, w), dtype=np.int32) + NO_OBJ + + for segment_info in results['segments_info']: + mask = maskUtils.decode(segment_info['mask']) + gt_seg = np.where(mask, segment_info['category'], gt_seg) + + # The legal thing masks + if segment_info.get('is_thing'): + gt_masks.append(mask.astype(np.uint8)) + + if self.with_mask: + h, w = results['ori_shape'] + gt_masks = BitmapMasks(gt_masks, h, w) + results['gt_masks'] = gt_masks + + if self.with_seg: + results['gt_seg_map'] = gt_seg + + def transform(self, results: dict) -> dict: + """Function to load multiple types panoptic annotations. + + Args: + results (dict): Result dict from :obj:``mmdet.CustomDataset``. + + Returns: + dict: The dict contains loaded bounding box, label, mask and + semantic segmentation annotations. + """ + + super().transform(results) + self._load_instances_ids(results) + return results + + +@TRANSFORMS.register_module() +class LoadJSONFromFile(BaseTransform): + """Load an json from file. + + Required Keys: + + - info_path + + Modified Keys: + + Args: + backend_args (dict, optional): Instantiates the corresponding file + backend. It may contain `backend` key to specify the file + backend. If it contains, the file backend corresponding to this + value will be used and initialized with the remaining values, + otherwise the corresponding file backend will be selected + based on the prefix of the file path. Defaults to None. + New in version 2.0.0rc4. + """ + + def __init__(self, backend_args: Optional[dict] = None) -> None: + self.backend_args: Optional[dict] = None + if backend_args is not None: + self.backend_args = backend_args.copy() + + def transform(self, results: dict) -> Optional[dict]: + """Functions to load image. + + Args: + results (dict): Result dict from + :class:`mmengine.dataset.BaseDataset`. + + Returns: + dict: The dict contains loaded image and meta information. 
+ """ + + filename = results['info_path'] + data_info = mmengine.load(filename, backend_args=self.backend_args) + + results['height'] = data_info['image']['height'] + results['width'] = data_info['image']['width'] + + # The code here are similar to `parse_data_info` in coco + instances = [] + for ann in sorted(data_info['annotations'], key=lambda x: -x['area']): + instance = {} + + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, results['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, results['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = 0 + + if ann.get('segmentation', None): + instance['mask'] = ann['segmentation'] + + if ann.get('point_coords', None): + instance['point_coords'] = ann['point_coords'] + + instances.append(instance) + + results['instances'] = instances + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'backend_args={self.backend_args})') + + return repr_str + + +@TRANSFORMS.register_module() +class LoadAnnotationsSAM(MMDET_LoadAnnotations): + + def __init__(self, *args, with_point_coords=False, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.with_point_coords = with_point_coords + + def _load_point_coords(self, results: dict) -> None: + assert self.with_point_coords + gt_point_coords = [] + for instance in results.get('instances', []): + gt_point_coords.append(instance['point_coords']) + results['gt_point_coords'] = np.array(gt_point_coords, dtype=np.float32) + + def transform(self, results: dict) -> Optional[dict]: + super().transform(results) + if self.with_point_coords: + self._load_point_coords(results) + return results + + +@TRANSFORMS.register_module() +class FilterAnnotationsHB(BaseTransform): + """Filter invalid annotations. + + Required Keys: + + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_ignore_flags (bool) (optional) + + Modified Keys: + + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_masks (optional) + - gt_ignore_flags (optional) + + Args: + min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth + boxes. Default: (1., 1.) + min_gt_mask_area (int): Minimum foreground area of ground truth masks. + Default: 1 + by_box (bool): Filter instances with bounding boxes not meeting the + min_gt_bbox_wh threshold. Default: True + by_mask (bool): Filter instances with masks not meeting + min_gt_mask_area threshold. Default: False + keep_empty (bool): Whether to return None when it + becomes an empty bbox after filtering. Defaults to True. + """ + + def __init__(self, + min_gt_bbox_wh: Tuple[int, int] = (1, 1), + min_gt_mask_area: int = 1, + by_box: bool = True, + by_mask: bool = False) -> None: + assert by_box or by_mask + self.min_gt_bbox_wh = min_gt_bbox_wh + self.min_gt_mask_area = min_gt_mask_area + self.by_box = by_box + self.by_mask = by_mask + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. 
+ """ + assert 'gt_bboxes' in results + gt_bboxes = results['gt_bboxes'] + if gt_bboxes.shape[0] == 0: + return None + + tests = [] + if self.by_box: + tests.append( + ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) & + (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy()) + if self.by_mask: + assert 'gt_masks' in results + gt_masks = results['gt_masks'] + tests.append(gt_masks.areas >= self.min_gt_mask_area) + + keep = tests[0] + for t in tests[1:]: + keep = keep & t + + results['gt_ignore_flags'] = np.logical_or(results['gt_ignore_flags'], np.logical_not(keep)) + if results['gt_ignore_flags'].all(): + return None + return results + + def __repr__(self): + return self.__class__.__name__ + + +@TRANSFORMS.register_module() +class GTNMS(BaseTransform): + + def __init__(self, + by_box: bool = True, + by_mask: bool = False + ) -> None: + assert by_box or by_mask and not (by_box and by_mask) + self.by_box = by_box + self.by_mask = by_mask + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + gt_ignore_flags = results['gt_ignore_flags'] + if self.by_box: + raise NotImplementedError + if self.by_mask: + assert 'gt_masks' in results + gt_masks = results['gt_masks'].masks + tot_mask = np.zeros_like(gt_masks[0], dtype=np.uint8) + for idx, mask in enumerate(gt_masks): + if gt_ignore_flags[idx]: + continue + overlapping = mask * tot_mask + ratio = overlapping.sum() / sum(mask).sum() + if ratio > 0.8: + # ignore with overlapping + gt_ignore_flags[idx] = True + continue + tot_mask = (tot_mask + mask).clip(max=1) + + results['gt_ignore_flags'] = gt_ignore_flags + return results + + def __repr__(self): + return self.__class__.__name__ + + +@TRANSFORMS.register_module() +class LoadFeatFromFile(BaseTransform): + + def __init__(self, model_name='vit_h'): + self.cache_suffix = f'_{model_name}_cache.pth' + + def transform(self, results: dict) -> Optional[dict]: + img_path = results['img_path'] + feat_path = img_path.replace('.jpg', self.cache_suffix) + assert mmengine.exists(feat_path) + feat = torch.load(feat_path) + results['feat'] = feat + return results + + def __repr__(self): + repr_str = f'{self.__class__.__name__}' + + return repr_str + + +@TRANSFORMS.register_module() +class ResizeOri(BaseTransform): + + def __init__( + self, + backend: str = 'cv2', + interpolation='bilinear' + ): + self.backend = backend + self.interpolation = interpolation + + def transform(self, results: dict) -> Optional[dict]: + results['ori_shape'] = results['img_shape'] + results['scale_factor'] = (1., 1.) + return results + + def __repr__(self): + repr_str = f'{self.__class__.__name__}' + return repr_str diff --git a/seg/datasets/pipeliens/transforms.py b/seg/datasets/pipeliens/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..52a8ad8bd374d4ace90ec62e36d14545007a5ef8 --- /dev/null +++ b/seg/datasets/pipeliens/transforms.py @@ -0,0 +1,135 @@ +import numpy as np + +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox import autocast_box_type + +from mmcv.image.geometric import _scale_size +from mmcv.transforms import Resize as MMCV_Resize +from mmdet.datasets.transforms import Resize as MMDET_Resize + + +@TRANSFORMS.register_module() +class ResizeImage(MMCV_Resize): + """Resize images only. + + This transform resizes the input image according to ``scale`` or + ``scale_factor``. 
Bboxes, masks, and seg map are then resized + with the same scale factor. + if ``scale`` and ``scale_factor`` are both set, it will use ``scale`` to + resize. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_masks (BitmapMasks | PolygonMasks) (optional) + - gt_seg_map (np.uint8) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes + - gt_masks + - gt_seg_map + + + Added Keys: + + - scale + - scale_factor + - keep_ratio + - homography_matrix + + Args: + scale (int or tuple): Images scales for resizing. Defaults to None + scale_factor (float or tuple[float]): Scale factors for resizing. + Defaults to None. + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. Defaults to False. + clip_object_border (bool): Whether to clip the objects + outside the border of the image. In some dataset like MOT17, the gt + bboxes are allowed to cross the border of images. Therefore, we + don't need to clip the gt bboxes in these cases. Defaults to True. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. Defaults + to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. Defaults + to 'bilinear'. + """ + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to resize images, bounding boxes and semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map', + 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys + are updated in result dict. + """ + if self.scale: + results['scale'] = self.scale + else: + img_shape = results['img'].shape[:2] + results['scale'] = _scale_size(img_shape[::-1], self.scale_factor) + self._resize_img(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'scale_factor={self.scale_factor}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'clip_object_border={self.clip_object_border}), ' + repr_str += f'backend={self.backend}), ' + repr_str += f'interpolation={self.interpolation})' + return repr_str + + +@TRANSFORMS.register_module() +class ResizeSAM(MMDET_Resize): + def _resize_point_coords(self, results: dict) -> None: + if results.get('gt_point_coords', None) is not None: + results['gt_point_coords'] = results['gt_point_coords'] * results['scale_factor'] + results['gt_point_coords'][..., 0] = np.clip(results['gt_point_coords'][..., 0], 0, results['img_shape'][1]) + results['gt_point_coords'][..., 1] = np.clip(results['gt_point_coords'][..., 1], 0, results['img_shape'][0]) + + @autocast_box_type() + def transform(self, results: dict) -> dict: + """Transform function to resize images, bounding boxes and semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map', + 'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys + are updated in result dict. 
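ResizeSAM._resize_point_coords above rescales prompt point coordinates by the image scale factor and clips them to the resized shape. The same arithmetic on toy data:

# Sketch of rescaling and clipping prompt point coordinates, mirroring
# _resize_point_coords on toy values.
import numpy as np

points = np.array([[[100., 50.], [800., 600.]]])   # (num_inst, num_pts, 2)
scale_factor = (0.5, 0.5)                          # (w_scale, h_scale)
new_shape = (300, 400)                             # (h, w) after resize

points = points * scale_factor
points[..., 0] = np.clip(points[..., 0], 0, new_shape[1])  # x within width
points[..., 1] = np.clip(points[..., 1], 0, new_shape[0])  # y within height
print(points)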
+ """ + if self.scale: + results['scale'] = self.scale + else: + img_shape = results['img'].shape[:2] + results['scale'] = _scale_size(img_shape[::-1], self.scale_factor) + self._resize_img(results) + self._resize_bboxes(results) + self._resize_masks(results) + self._resize_seg(results) + self._resize_point_coords(results) + self._record_homography_matrix(results) + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'scale_factor={self.scale_factor}, ' + repr_str += f'keep_ratio={self.keep_ratio}, ' + repr_str += f'clip_object_border={self.clip_object_border}), ' + repr_str += f'backend={self.backend}), ' + repr_str += f'interpolation={self.interpolation})' + return repr_str diff --git a/seg/datasets/samplers/batch_sampler.py b/seg/datasets/samplers/batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..00d33c11caf498eb262cff1ea1d4d3383a6c3476 --- /dev/null +++ b/seg/datasets/samplers/batch_sampler.py @@ -0,0 +1,154 @@ +from typing import Sequence + +import torch +import torch.distributed as torch_dist +from mmengine.dist import get_dist_info, get_default_group, get_comm_device +from torch._C._distributed_c10d import ReduceOp +from torch.utils.data import Sampler, BatchSampler + +from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler +from mmdet.registry import DATA_SAMPLERS + + +@DATA_SAMPLERS.register_module() +class VideoSegAspectRatioBatchSampler(AspectRatioBatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio (< 1 or. + + >= 1) into a same batch. + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + """ + + def __iter__(self) -> Sequence[int]: + for idx in self.sampler: + # hard code to solve TrackImgSampler + video_idx = idx + # video_idx + data_info = self.sampler.dataset.get_data_info(video_idx) + # data_info {video_id, images, video_length} + if 'images' in data_info: + img_data_info = data_info['images'][0] + else: + img_data_info = data_info + width, height = img_data_info['width'], img_data_info['height'] + bucket_id = 0 if width < height else 1 + bucket = self._aspect_ratio_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + left_data = self._aspect_ratio_buckets[0] + self._aspect_ratio_buckets[ + 1] + self._aspect_ratio_buckets = [[] for _ in range(2)] + while len(left_data) > 0: + if len(left_data) <= self.batch_size: + if not self.drop_last: + yield left_data[:] + left_data = [] + else: + yield left_data[:self.batch_size] + left_data = left_data[self.batch_size:] + + +@DATA_SAMPLERS.register_module() +class MultiDataAspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio (< 1 or. + + >= 1) into a same batch for multi-source datasets. + + Args: + sampler (Sampler): Base sampler. + batch_size (Sequence(int)): Size of mini-batch for multi-source + datasets. + num_datasets(int): Number of multi-source datasets. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. 
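+
+    Example (illustrative only; the dataset index and image size below are
+    hypothetical)::
+
+        num_datasets = 2                         # -> 2 * 2 = 4 buckets
+        i, w, h = 1, 1920, 1080                  # landscape image of dataset 1
+        bucket_id = i * 2 + (0 if w < h else 1)  # -> 3
+        # a batch for dataset i is yielded once this bucket holds
+        # batch_size[i] indices (see ``__iter__`` below).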
+ """ + + def __init__(self, + sampler: Sampler, + batch_size: Sequence[int], + num_datasets: int, + drop_last: bool = True) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + self.sampler = sampler + if isinstance(batch_size, int): + self.batch_size = [batch_size] * num_datasets + else: + self.batch_size = batch_size + self.num_datasets = num_datasets + self.drop_last = drop_last + # two groups for w < h and w >= h for each dataset --> 2 * num_datasets + self._buckets = [[] for _ in range(2 * self.num_datasets)] + + def __iter__(self) -> Sequence[int]: + num_batch = torch.tensor(len(self), device='cpu') + rank, world_size = get_dist_info() + if world_size > 1: + group = get_default_group() + backend_device = get_comm_device(group) + num_batch = num_batch.to(device=backend_device) + torch_dist.all_reduce(num_batch, op=ReduceOp.MIN, group=group) + num_batch = num_batch.to('cpu').item() + + for idx in self.sampler: + data_info = self.sampler.dataset.get_data_info(idx) + width, height = data_info.get('width', 0), data_info.get('height', 0) + dataset_source_idx = self.sampler.dataset.get_dataset_source(idx) + aspect_ratio_bucket_id = 0 if width < height else 1 + bucket_id = dataset_source_idx * 2 + aspect_ratio_bucket_id + bucket = self._buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size[dataset_source_idx]: + yield bucket[:] + num_batch -= 1 + if num_batch <= 0: + return + del bucket[:] + + # yield the rest data and reset the bucket + for i in range(self.num_datasets): + left_data = self._buckets[i * 2 + 0] + self._buckets[i * 2 + 1] + while len(left_data) > 0: + if len(left_data) < self.batch_size[i]: + if not self.drop_last: + yield left_data[:] + num_batch -= 1 + if num_batch <= 0: + return + left_data = [] + else: + yield left_data[:self.batch_size[i]] + num_batch -= 1 + if num_batch <= 0: + return + left_data = left_data[self.batch_size[i]:] + + self._buckets = [[] for _ in range(2 * self.num_datasets)] + + def __len__(self) -> int: + sizes = [0 for _ in range(self.num_datasets)] + for idx in self.sampler: + dataset_source_idx = self.sampler.dataset.get_dataset_source(idx) + sizes[dataset_source_idx] += 1 + + if self.drop_last: + lens = 0 + for i in range(self.num_datasets): + lens += sizes[i] // self.batch_size[i] + return lens + else: + lens = 0 + for i in range(self.num_datasets): + lens += (sizes[i] + self.batch_size[i] - 1) // self.batch_size[i] + return lens diff --git a/seg/datasets/samplers/multi_dataset_sampler.py b/seg/datasets/samplers/multi_dataset_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c28fe6e719c610627f56663ae883af2d3d9feb65 --- /dev/null +++ b/seg/datasets/samplers/multi_dataset_sampler.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Iterator, Optional, Sequence, Sized + +import torch +from mmengine.dist import get_dist_info, sync_random_seed +from mmengine.registry import DATA_SAMPLERS +from torch.utils.data import Sampler + + +@DATA_SAMPLERS.register_module() +class MultiDataSampler(Sampler): + """The default data sampler for both distributed and non-distributed + environment. + + It has several differences from the PyTorch ``DistributedSampler`` as + below: + + 1. This sampler supports non-distributed environment. + + 2. The round up behaviors are a little different. 
+ + - If ``round_up=True``, this sampler will add extra samples to make the + number of samples is evenly divisible by the world size. And + this behavior is the same as the ``DistributedSampler`` with + ``drop_last=False``. + - If ``round_up=False``, this sampler won't remove or add any samples + while the ``DistributedSampler`` with ``drop_last=True`` will remove + tail samples. + + Args: + dataset (Sized): The dataset. + dataset_ratio (Sequence(int)) The ratios of different datasets. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. + round_up (bool): Whether to add extra samples to make the number of + samples evenly divisible by the world size. Defaults to True. + """ + + def __init__(self, + dataset: Sized, + dataset_ratio: Sequence[int], + seed: Optional[int] = None, + round_up: bool = True) -> None: + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.dataset_ratio = dataset_ratio + + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil(len(self.dataset) / world_size) + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) + + self.sizes = [len(dataset) for dataset in self.dataset.datasets] + + dataset_weight = [ + torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio) + for i, (r, s) in enumerate(zip(self.dataset_ratio, self.sizes)) + ] + self.weights = torch.cat(dataset_weight) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + + indices = torch.multinomial( + self.weights, len(self.weights), generator=g, + replacement=True).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + + # subsample + indices = indices[self.rank:self.total_size:self.world_size] + + return iter(indices) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
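+
+        Note:
+            For reference, the per-sample weights built in ``__init__`` follow
+            ``max(sizes) / sizes[i] * dataset_ratio[i] / sum(dataset_ratio)``.
+            With hypothetical ``sizes=[100, 25]`` and ``dataset_ratio=[1, 1]``,
+            each sample of the smaller dataset gets weight ``2.0`` and each
+            sample of the larger one ``0.5``, so ``torch.multinomial`` draws
+            the two datasets about equally often.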
+ """ + self.epoch = epoch diff --git a/seg/datasets/vipseg.py b/seg/datasets/vipseg.py new file mode 100644 index 0000000000000000000000000000000000000000..a87d87c6a8709843cc77b66cf377f2520cb70c1e --- /dev/null +++ b/seg/datasets/vipseg.py @@ -0,0 +1,239 @@ +import logging +import os +from typing import List + +import numpy as np +import pycocotools.mask as maskUtils + +import mmcv +from mmengine import print_log, list_from_file, scandir, track_parallel_progress +from mmengine.dist import master_only, dist +from mmengine.fileio import join_path, exists, load, dump + +from mmdet.datasets import BaseVideoDataset +from mmdet.registry import DATASETS + +from seg.models.utils import INSTANCE_OFFSET_HB + +from ext.class_names.VIPSeg import CLASSES_THING, CLASSES_STUFF, COCO_CLASSES, COCO_THINGS, COCO_STUFF, PALETTE + +NO_OBJ = 0 +NO_OBJ_HB = 255 +NO_OBJ_BUG = (200,) +DIVISOR_PAN = 100 +NUM_THING = 58 +NUM_STUFF = 66 + + +def to_coco(pan_map, divisor=INSTANCE_OFFSET_HB): + pan_new = - np.ones_like(pan_map) + vip2hb_thing = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_THING)} + assert len(vip2hb_thing) == NUM_THING + vip2hb_stuff = {itm['id'] + 1: idx for idx, itm in enumerate(CLASSES_STUFF)} + assert len(vip2hb_stuff) == NUM_STUFF + for idx in np.unique(pan_map): + # 200 is a bug in vipseg dataset. + # Please refer to https://github.com/VIPSeg-Dataset/VIPSeg-Dataset/issues/1 + if idx == NO_OBJ or idx in NO_OBJ_BUG: + pan_new[pan_map == idx] = NO_OBJ_HB * divisor + elif idx > 128: + cls_id = idx // DIVISOR_PAN + cls_new_id = vip2hb_thing[cls_id] + inst_id = idx % DIVISOR_PAN + pan_new[pan_map == idx] = cls_new_id * divisor + inst_id + 1 + else: + cls_new_id = vip2hb_stuff[idx] + cls_new_id += NUM_THING + pan_new[pan_map == idx] = cls_new_id * divisor + assert -1 not in np.unique(pan_new) + return pan_new + + +def mask2bbox(mask): + bbox = np.zeros((4,), dtype=np.float32) + x_any = np.any(mask, axis=0) + y_any = np.any(mask, axis=1) + x = np.where(x_any)[0] + y = np.where(y_any)[0] + if len(x) > 0 and len(y) > 0: + bbox = np.array((x[0], y[0], x[-1], y[-1]), dtype=np.float32) + return bbox + + +def video_parser(params): + seq_id, vid_folder, ann_folder = params + images = [] + assert os.path.basename(vid_folder) == os.path.basename(ann_folder) + _tmp_img_id = -1 + imgs_cur = sorted(list(map( + lambda x: str(x), scandir(vid_folder, recursive=False, suffix='.jpg') + ))) + pans_cur = sorted(list(map( + lambda x: str(x), scandir(ann_folder, recursive=False, suffix='.png') + ))) + for img_cur, pan_cur in zip(imgs_cur, pans_cur): + assert img_cur.split('.')[0] == pan_cur.split('.')[0] + _tmp_img_id += 1 + img_id = _tmp_img_id + item_full = os.path.join(vid_folder, img_cur) + inst_map = os.path.join(ann_folder, pan_cur) + img_dict = { + 'img_path': item_full, + 'ann_path': inst_map, + } + assert os.path.exists(img_dict['img_path']) + assert os.path.exists(img_dict['ann_path']) + instances = [] + ann_map = mmcv.imread(img_dict['ann_path'], flag='unchanged').astype(np.uint32) + img_dict['height'], img_dict['width'] = ann_map.shape + pan_map = to_coco(ann_map, INSTANCE_OFFSET_HB) + + for pan_seg_id in np.unique(pan_map): + label = pan_seg_id // INSTANCE_OFFSET_HB + if label == NO_OBJ_HB: + continue + instance = {} + mask = (pan_map == pan_seg_id).astype(np.uint8) + instance['instance_id'] = pan_seg_id + instance['bbox'] = mask2bbox(mask) + instance['bbox_label'] = label + instance['ignore_flag'] = 0 + instance['mask'] = maskUtils.encode(np.asfortranarray(mask)) + instance['mask']['counts'] = 
instance['mask']['counts'].decode() + instances.append(instance) + img_dict['instances'] = instances + img_dict['video_id'] = seq_id + img_dict['frame_id'] = img_id + img_dict['img_id'] = seq_id * 10000 + img_id + images.append(img_dict) + return { + 'video_id': seq_id, + 'images': images, + 'video_length': len(images) + } + + +@DATASETS.register_module() +class VIPSegDataset(BaseVideoDataset): + METAINFO = { + 'classes': COCO_CLASSES, + 'thing_classes': COCO_THINGS, + 'stuff_classes': COCO_STUFF, + 'palette': PALETTE, + } + + def __init__( + self, + *args, + img_map_suffix: str = '.jpg', + seg_map_suffix: str = '.png', + **kwargs + ): + self.img_map_suffix = img_map_suffix + self.seg_map_suffix = seg_map_suffix + super().__init__(*args, **kwargs) + + @master_only + def build_cache(self, ann_json_path, video_folders, ann_folders) -> None: + vid_ids = range(len(video_folders)) + + data_list = track_parallel_progress( + video_parser, + tasks=list(zip(vid_ids, video_folders, ann_folders)), + nproc=20, + keep_order=False, + ) + data_list = sorted(data_list, key=lambda x: x['video_id']) + dump(data_list, ann_json_path) + + def load_data_list(self) -> List[dict]: + video_folders = list_from_file(self.ann_file, prefix=self.data_prefix['img']) + ann_folders = list_from_file(self.ann_file, prefix=self.data_prefix['seg']) + assert len(video_folders) == len(ann_folders) + print_log(f"#videos : {len(video_folders)} ", + logger='current', + level=logging.INFO) + + split = os.path.basename(self.ann_file).split('.')[0] + ann_json_path = f"{split}_annotations.json" + ann_json_path = join_path(self.data_root, ann_json_path) + if not exists(ann_json_path): + self.build_cache(ann_json_path, video_folders, ann_folders) + dist.barrier() + raw_data_list = load(ann_json_path) + data_list = [] + for raw_data_info in raw_data_list: + data_info = self.parse_data_info(raw_data_info) + data_list.append(data_info) + vid_len_list = [itm['video_length'] for itm in data_list] + max_vid_len = max(vid_len_list) + min_vid_len = min(vid_len_list) + print_log( + f"Max video len : {max_vid_len}; " + f"Min video len : {min_vid_len}." + , + logger='current', + level=logging.INFO + ) + return data_list + + def parse_data_info(self, raw_data_info: dict) -> dict: + data_info = { + 'video_id': raw_data_info['video_id'], + 'video_length': raw_data_info['video_length'] + } + images = [] + for raw_img_data_info in raw_data_info['images']: + img_data_info = { + 'img_path': raw_img_data_info['img_path'], + 'height': raw_img_data_info['height'], + 'width': raw_img_data_info['width'], + 'video_id': raw_img_data_info['video_id'], + 'frame_id': raw_img_data_info['frame_id'], + 'img_id': raw_img_data_info['img_id'] + } + instances = [] + segments_info = [] + for ann in raw_img_data_info['instances']: + instance = {} + category_id = ann['bbox_label'] + bbox = ann['bbox'] + is_thing = category_id < NUM_THING + if is_thing: + instance['bbox'] = bbox + instance['bbox_label'] = category_id + instance['ignore_flag'] = ann['ignore_flag'] + instance['instance_id'] = ann['instance_id'] + + segment_info = { + 'mask': ann['mask'], + 'category': category_id, + 'is_thing': is_thing + } + segments_info.append(segment_info) + if len(instance) > 0 and is_thing: + instances.append(instance) + img_data_info['instances'] = instances + img_data_info['segments_info'] = segments_info + images.append(img_data_info) + data_info['images'] = images + return data_info + + def filter_data(self) -> List[dict]: + """Filter image annotations according to filter_cfg. 
+ + Returns: + list[int]: Filtered results. + """ + if self.test_mode: + return self.data_list + + num_imgs_before_filter = sum([len(info['images']) for info in self.data_list]) + num_imgs_after_filter = num_imgs_before_filter + + new_data_list = self.data_list + + print_log( + 'The number of samples before and after filtering: ' + f'{num_imgs_before_filter} / {num_imgs_after_filter}', 'current') + return new_data_list diff --git a/seg/datasets/youtube_vis_dataset.py b/seg/datasets/youtube_vis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b9bd0fb6342cc0da14112d75a608654fd753cba0 --- /dev/null +++ b/seg/datasets/youtube_vis_dataset.py @@ -0,0 +1,60 @@ +from mmdet.registry import DATASETS +from mmdet.datasets.base_video_dataset import BaseVideoDataset + + +@DATASETS.register_module() +class YouTubeVISDatasetV2(BaseVideoDataset): + """YouTube VIS dataset for video instance segmentation. + + Args: + dataset_version (str): Select dataset year version. + """ + + def __init__(self, dataset_version: str, *args, **kwargs): + self.set_dataset_classes(dataset_version) + self.dataset_name = f'YouTubeVISDataset_{dataset_version}' + super().__init__(*args, **kwargs) + + @classmethod + def set_dataset_classes(cls, dataset_version: str) -> None: + """Pass the category of the corresponding year to metainfo. + + Args: + dataset_version (str): Select dataset year version. + """ + classes_2019_version = ('person', 'giant_panda', 'lizard', 'parrot', + 'skateboard', 'sedan', 'ape', 'dog', 'snake', + 'monkey', 'hand', 'rabbit', 'duck', 'cat', + 'cow', 'fish', 'train', 'horse', 'turtle', + 'bear', 'motorbike', 'giraffe', 'leopard', + 'fox', 'deer', 'owl', 'surfboard', 'airplane', + 'truck', 'zebra', 'tiger', 'elephant', + 'snowboard', 'boat', 'shark', 'mouse', 'frog', + 'eagle', 'earless_seal', 'tennis_racket') + + classes_2021_version = ('airplane', 'bear', 'bird', 'boat', 'car', + 'cat', 'cow', 'deer', 'dog', 'duck', + 'earless_seal', 'elephant', 'fish', + 'flying_disc', 'fox', 'frog', 'giant_panda', + 'giraffe', 'horse', 'leopard', 'lizard', + 'monkey', 'motorbike', 'mouse', 'parrot', + 'person', 'rabbit', 'shark', 'skateboard', + 'snake', 'snowboard', 'squirrel', 'surfboard', + 'tennis_racket', 'tiger', 'train', 'truck', + 'turtle', 'whale', 'zebra') + + classes_ovis_version = ('Person', 'Bird', 'Cat', 'Dog', 'Horse', + 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', + 'Giraffe', 'Poultry', 'Giant_panda', 'Lizard', 'Parrot', + 'Monkey', 'Rabbit', 'Tiger', 'Fish', 'Turtle', + 'Bicycle', 'Motorcycle', 'Airplane', 'Boat', 'Vehical') + + if dataset_version == '2019': + cls.METAINFO = dict(classes=classes_2019_version) + elif dataset_version == '2021': + cls.METAINFO = dict(classes=classes_2021_version) + elif dataset_version == 'ovis': + cls.METAINFO = dict(classes=classes_ovis_version) + else: + raise NotImplementedError('Not supported YouTubeVIS dataset' + f'version: {dataset_version}') diff --git a/seg/evaluation/metrics/cityscapes_panoptic_metric.py b/seg/evaluation/metrics/cityscapes_panoptic_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a505cf0c4c0fc0067a07de38822a2d7f5f5657 --- /dev/null +++ b/seg/evaluation/metrics/cityscapes_panoptic_metric.py @@ -0,0 +1,618 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import datetime +import itertools +import os.path +import os.path as osp +import tempfile +from typing import Dict, Optional, Sequence, Tuple, Union + +import mmcv +import numpy as np +from mmengine.evaluator import BaseMetric +from mmengine.fileio import dump, get_local_path, load +from mmengine.logging import MMLogger, print_log +from terminaltables import AsciiTable + +from mmdet.datasets.api_wrappers import COCOPanoptic +from mmdet.registry import METRICS +from mmdet.evaluation.functional import (INSTANCE_OFFSET, pq_compute_multi_core, pq_compute_single_core) + +try: + import panopticapi + from panopticapi.evaluation import VOID, PQStat + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + panopticapi = None + id2rgb = None + rgb2id = None + VOID = None + PQStat = None + + +@METRICS.register_module() +class CityscapesPanopticMetric(BaseMetric): + """COCO panoptic segmentation evaluation metric. + + Evaluate PQ, SQ RQ for panoptic segmentation tasks. Please refer to + https://cocodataset.org/#panoptic-eval for more details. + + Args: + ann_file (str, optional): Path to the coco format annotation file. + If not specified, ground truth annotations from the dataset will + be converted to coco format. Defaults to None. + seg_prefix (str, optional): Path to the directory which contains the + coco panoptic segmentation mask. It should be specified when + evaluate. Defaults to None. + classwise (bool): Whether to evaluate the metric class-wise. + Defaults to False. + outfile_prefix (str, optional): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. + It should be specified when format_only is True. Defaults to None. + format_only (bool): Format the output results without perform + evaluation. It is useful when you want to format the result + to a specific format and submit it to the test server. + Defaults to False. + nproc (int): Number of processes for panoptic quality computing. + Defaults to 32. When ``nproc`` exceeds the number of cpu cores, + the number of cpu cores is used. + file_client_args (dict, optional): Arguments to instantiate the + corresponding backend in mmdet <= 3.0.0rc6. Defaults to None. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. 
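+
+    Example (illustrative evaluator config; the ``seg_prefix`` path is a
+    placeholder)::
+
+        val_evaluator = dict(
+            type='CityscapesPanopticMetric',
+            seg_prefix='data/cityscapes/gtFine/cityscapes_panoptic_val',
+            classwise=True)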
+ """ + default_prefix: Optional[str] = 'coco_panoptic' + + def __init__(self, + ann_file: Optional[str] = None, + seg_prefix: Optional[str] = None, + classwise: bool = False, + format_only: bool = False, + outfile_prefix: Optional[str] = None, + nproc: int = 32, + file_client_args: dict = None, + backend_args: dict = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None) -> None: + if panopticapi is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + + super().__init__(collect_device=collect_device, prefix=prefix) + self.classwise = classwise + self.format_only = format_only + if self.format_only: + assert outfile_prefix is not None, 'outfile_prefix must be not' + 'None when format_only is True, otherwise the result files will' + 'be saved to a temp directory which will be cleaned up at the end.' + + self.tmp_dir = None + # outfile_prefix should be a prefix of a path which points to a shared + # storage when train or test with multi nodes. + self.outfile_prefix = outfile_prefix + if outfile_prefix is None: + self.tmp_dir = tempfile.TemporaryDirectory() + self.outfile_prefix = osp.join(self.tmp_dir.name, 'results') + # the directory to save predicted panoptic segmentation mask + self.seg_out_dir = f'{self.outfile_prefix}.panoptic' + self.nproc = nproc + self.seg_prefix = seg_prefix + + self.cat_ids = None + self.cat2label = None + + self.backend_args = backend_args + if file_client_args is not None: + raise RuntimeError( + 'The `file_client_args` is deprecated, ' + 'please use `backend_args` instead, please refer to' + 'https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/datasets/coco_detection.py' # noqa: E501 + ) + + if ann_file: + with get_local_path( + ann_file, backend_args=self.backend_args) as local_path: + self._coco_api = COCOPanoptic(local_path) + self.categories = self._coco_api.cats + else: + self._coco_api = None + self.categories = None + + def __del__(self) -> None: + """Clean up.""" + if self.tmp_dir is not None: + self.tmp_dir.cleanup() + + def gt_to_coco_json(self, gt_dicts: Sequence[dict], + outfile_prefix: str) -> Tuple[str, str]: + """Convert ground truth to coco panoptic segmentation format json file. + + Args: + gt_dicts (Sequence[dict]): Ground truth of the dataset. + outfile_prefix (str): The filename prefix of the json file. If the + prefix is "somepath/xxx", the json file will be named + "somepath/xxx.gt.json". + + Returns: + Tuple[str, str]: The filename of the json file and the name of the\ + directory which contains panoptic segmentation masks. + """ + assert len(gt_dicts) > 0, 'gt_dicts is empty.' 
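+        # A sketch of the per-image dict expected here, inferred from the usage
+        # below (keys that this method does not read are omitted):
+        #   {'image_id': ..., 'width': ..., 'height': ...,
+        #    'seg_map_path': '.../xxx_panoptic.png',
+        #    'segments_info': [{'id': ..., 'category': ..., 'is_thing': ...}, ...]}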
+ gt_folder = osp.dirname(gt_dicts[0]['seg_map_path']) + converted_json_path = f'{outfile_prefix}.gt.json' + + categories = [] + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta['thing_classes'] else 0 + categories.append({'id': id, 'name': name, 'isthing': isthing}) + + image_infos = [] + annotations = [] + for gt_dict in gt_dicts: + img_id = gt_dict['image_id'] + image_info = { + 'id': img_id, + 'width': gt_dict['width'], + 'height': gt_dict['height'], + 'file_name': osp.split(gt_dict['seg_map_path'])[-1] + } + image_infos.append(image_info) + + pan_png = mmcv.imread(gt_dict['seg_map_path']).squeeze() + pan_png = pan_png[:, :, ::-1] + pan_png = rgb2id(pan_png) + segments_info = [] + for segment_info in gt_dict['segments_info']: + id = segment_info['id'] + label = segment_info['category'] + mask = pan_png == id + isthing = categories[label]['isthing'] + if isthing: + iscrowd = 1 if not segment_info['is_thing'] else 0 + else: + iscrowd = 0 + + new_segment_info = { + 'id': id, + 'category_id': label, + 'isthing': isthing, + 'iscrowd': iscrowd, + 'area': mask.sum() + } + segments_info.append(new_segment_info) + + segm_file = image_info['file_name'].replace("_leftImg8bit.png", "_panoptic.png") + annotation = dict( + image_id=img_id, + segments_info=segments_info, + file_name=segm_file) + annotations.append(annotation) + pan_png = id2rgb(pan_png) + + info = dict( + date_created=str(datetime.datetime.now()), + description='Coco json file converted by mmdet CocoPanopticMetric.' + ) + coco_json = dict( + info=info, + images=image_infos, + categories=categories, + licenses=None, + ) + if len(annotations) > 0: + coco_json['annotations'] = annotations + dump(coco_json, converted_json_path) + return converted_json_path, gt_folder + + def result2json(self, results: Sequence[dict], + outfile_prefix: str) -> Tuple[str, str]: + """Dump the panoptic results to a COCO style json file and a directory. + + Args: + results (Sequence[dict]): Testing results of the dataset. + outfile_prefix (str): The filename prefix of the json files and the + directory. + + Returns: + Tuple[str, str]: The json file and the directory which contains \ + panoptic segmentation masks. The filename of the json is + "somepath/xxx.panoptic.json" and name of the directory is + "somepath/xxx.panoptic". + """ + label2cat = dict((v, k) for (k, v) in self.cat2label.items()) + pred_annotations = [] + for idx in range(len(results)): + result = results[idx] + for segment_info in result['segments_info']: + sem_label = segment_info['category_id'] + # convert sem_label to json label + cat_id = label2cat[sem_label] + segment_info['category_id'] = label2cat[sem_label] + is_thing = self.categories[cat_id]['isthing'] + segment_info['isthing'] = is_thing + pred_annotations.append(result) + pan_json_results = dict(annotations=pred_annotations) + json_filename = f'{outfile_prefix}.panoptic.json' + dump(pan_json_results, json_filename) + return json_filename, ( + self.seg_out_dir + if self.tmp_dir is None else tempfile.gettempdir()) + + def _parse_predictions(self, + pred: dict, + img_id: int, + segm_file: str, + label2cat=None) -> dict: + """Parse panoptic segmentation predictions. + + Args: + pred (dict): Panoptic segmentation predictions. + img_id (int): Image id. + segm_file (str): Segmentation file name. + label2cat (dict): Mapping from label to category id. + Defaults to None. + + Returns: + dict: Parsed predictions. 
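+
+        Example (illustrative; ``INSTANCE_OFFSET`` is the constant imported
+        from mmdet above, assumed to be 1000)::
+
+            pan_label = 7 * INSTANCE_OFFSET + 13          # 7th instance of class 13
+            sem_label = pan_label % INSTANCE_OFFSET       # -> 13
+            instance_part = pan_label // INSTANCE_OFFSET  # -> 7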
+ """ + result = dict() + result['img_id'] = img_id + # shape (1, H, W) -> (H, W) + pan = pred['pred_panoptic_seg']['sem_seg'].cpu().numpy()[0] + ignore_index = pred['pred_panoptic_seg'].get( + 'ignore_index', len(self.dataset_meta['classes'])) + pan_labels = np.unique(pan) + segments_info = [] + for pan_label in pan_labels: + sem_label = pan_label % INSTANCE_OFFSET + # We reserve the length of dataset_meta['classes'] + # and ignore_index for VOID label + if sem_label == len( + self.dataset_meta['classes']) or sem_label == ignore_index: + continue + mask = pan == pan_label + area = mask.sum() + segments_info.append({ + 'id': + int(pan_label), + # when ann_file provided, sem_label should be cat_id, otherwise + # sem_label should be a continuous id, not the cat_id + # defined in dataset + 'category_id': + label2cat[sem_label] if label2cat else sem_label, + 'area': + int(area) + }) + # evaluation script uses 0 for VOID label. + pan[pan % INSTANCE_OFFSET == len(self.dataset_meta['classes'])] = VOID + pan[pan % INSTANCE_OFFSET == ignore_index] = VOID + + pan = id2rgb(pan).astype(np.uint8) + mmcv.imwrite(pan[:, :, ::-1], osp.join(self.seg_out_dir, segm_file)) + result = { + 'image_id': img_id, + 'segments_info': segments_info, + 'file_name': segm_file + } + + return result + + def _compute_batch_pq_stats(self, data_samples: Sequence[dict]): + """Process gts and predictions when ``outfile_prefix`` is not set, gts + are from dataset or a json file which is defined by ``ann_file``. + + Intermediate results, ``pq_stats``, are computed here and put into + ``self.results``. + """ + if self._coco_api is None: + categories = dict() + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta['thing_classes']\ + else 0 + categories[id] = {'id': id, 'name': name, 'isthing': isthing} + label2cat = None + else: + categories = self.categories + cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + label2cat = {i: cat_id for i, cat_id in enumerate(cat_ids)} + + for data_sample in data_samples: + # parse pred + img_id = data_sample['img_id'] + # segm_file = osp.basename(data_sample['img_path']).replace('jpg', 'png') + segm_file = osp.basename(data_sample['img_path']).replace("_leftImg8bit.png", "_panoptic.png") + segm_file = os.path.join(os.path.basename(os.path.dirname(data_sample['img_path'])), segm_file) + result = self._parse_predictions( + pred=data_sample, + img_id=img_id, + segm_file=segm_file, + label2cat=label2cat) + + # parse gt + gt = dict() + gt['image_id'] = img_id + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + gt['file_name'] = segm_file + + if self._coco_api is None: + # get segments_info from data_sample + seg_map_path = osp.join(self.seg_prefix, segm_file) + pan_png = mmcv.imread(seg_map_path).squeeze() + pan_png = pan_png[:, :, ::-1] + pan_png = rgb2id(pan_png) + segments_info = [] + + for segment_info in data_sample['segments_info']: + id = segment_info['id'] + label = segment_info['category'] + mask = pan_png == id + isthing = categories[label]['isthing'] + if isthing: + iscrowd = 1 if not segment_info['is_thing'] else 0 + else: + iscrowd = 0 + + new_segment_info = { + 'id': id, + 'category_id': label, + 'isthing': isthing, + 'iscrowd': iscrowd, + 'area': mask.sum() + } + segments_info.append(new_segment_info) + else: + # get segments_info from annotation file + segments_info = self._coco_api.imgToAnns[img_id] + + gt['segments_info'] = segments_info + + pq_stats = 
pq_compute_single_core( + proc_id=0, + annotation_set=[(gt, result)], + gt_folder=self.seg_prefix, + pred_folder=self.seg_out_dir, + categories=categories, + backend_args=self.backend_args) + + self.results.append(pq_stats) + + def _process_gt_and_predictions(self, data_samples: Sequence[dict]): + """Process gts and predictions when ``outfile_prefix`` is set. + + The predictions will be saved to directory specified by + ``outfile_predfix``. The matched pair (gt, result) will be put into + ``self.results``. + """ + for data_sample in data_samples: + # parse pred + img_id = data_sample['img_id'] + segm_file = osp.basename(data_sample['img_path']).replace("_leftImg8bit.png", "_panoptic.png") + result = self._parse_predictions( + pred=data_sample, img_id=img_id, segm_file=segm_file) + + # parse gt + gt = dict() + gt['image_id'] = img_id + gt['width'] = data_sample['ori_shape'][1] + gt['height'] = data_sample['ori_shape'][0] + + if self._coco_api is None: + # get segments_info from dataset + gt['segments_info'] = data_sample['segments_info'] + gt['seg_map_path'] = data_sample['seg_map_path'] + + self.results.append((gt, result)) + + # TODO: data_batch is no longer needed, consider adjusting the + # parameter position + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of data samples that + contain annotations and predictions. + """ + # If ``self.tmp_dir`` is none, it will save gt and predictions to + # self.results, otherwise, it will compute pq_stats here. + if self.tmp_dir is None: + self._process_gt_and_predictions(data_samples) + else: + self._compute_batch_pq_stats(data_samples) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. There + are two cases: + + - When ``outfile_prefix`` is not provided, the elements in + results are pq_stats which can be summed directly to get PQ. + - When ``outfile_prefix`` is provided, the elements in + results are tuples like (gt, pred). + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. 
+ """ + logger: MMLogger = MMLogger.get_current_instance() + + if self.tmp_dir is None: + # do evaluation after collect all the results + + # split gt and prediction list + gts, preds = zip(*results) + + if self._coco_api is None: + # use converted gt json file to initialize coco api + logger.info('Converting ground truth to coco format...') + coco_json_path, gt_folder = self.gt_to_coco_json( + gt_dicts=gts, outfile_prefix=self.outfile_prefix) + self._coco_api = COCOPanoptic(coco_json_path) + else: + gt_folder = self.seg_prefix + + self.cat_ids = self._coco_api.get_cat_ids( + cat_names=self.dataset_meta['classes']) + self.cat2label = { + cat_id: i + for i, cat_id in enumerate(self.cat_ids) + } + self.img_ids = self._coco_api.get_img_ids() + self.categories = self._coco_api.cats + + # convert predictions to coco format and dump to json file + json_filename, pred_folder = self.result2json( + results=preds, outfile_prefix=self.outfile_prefix) + + if self.format_only: + logger.info('results are saved in ' + f'{osp.dirname(self.outfile_prefix)}') + return dict() + + imgs = self._coco_api.imgs + gt_json = self._coco_api.img_ann_map + gt_json = [{ + 'image_id': k, + 'segments_info': v, + 'file_name': imgs[k]['segm_file'] + } for k, v in gt_json.items()] + pred_json = load(json_filename) + pred_json = dict( + (el['image_id'], el) for el in pred_json['annotations']) + + # match the gt_anns and pred_anns in the same image + matched_annotations_list = [] + for gt_ann in gt_json: + img_id = gt_ann['image_id'] + if img_id not in pred_json.keys(): + raise Exception('no prediction for the image' + ' with id: {}'.format(img_id)) + matched_annotations_list.append((gt_ann, pred_json[img_id])) + + pq_stat = pq_compute_multi_core( + matched_annotations_list, + gt_folder, + pred_folder, + self.categories, + backend_args=self.backend_args, + nproc=self.nproc) + + else: + # aggregate the results generated in process + if self._coco_api is None: + categories = dict() + for id, name in enumerate(self.dataset_meta['classes']): + isthing = 1 if name in self.dataset_meta[ + 'thing_classes'] else 0 + categories[id] = { + 'id': id, + 'name': name, + 'isthing': isthing + } + self.categories = categories + + pq_stat = PQStat() + for result in results: + pq_stat += result + + metrics = [('All', None), ('Things', True), ('Stuff', False)] + pq_results = {} + + for name, isthing in metrics: + pq_results[name], classwise_results = pq_stat.pq_average( + self.categories, isthing=isthing) + if name == 'All': + pq_results['classwise'] = classwise_results + + classwise_results = None + if self.classwise: + classwise_results = { + k: v + for k, v in zip(self.dataset_meta['classes'], + pq_results['classwise'].values()) + } + + print_panoptic_table(pq_results, classwise_results, logger=logger) + results = parse_pq_results(pq_results) + + return results + + +def parse_pq_results(pq_results: dict) -> dict: + """Parse the Panoptic Quality results. + + Args: + pq_results (dict): Panoptic Quality results. + + Returns: + dict: Panoptic Quality results parsed. 
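+
+    Example (illustrative numbers; the nesting mirrors ``pq_average`` output)::
+
+        pq_results = {'All':    {'pq': 0.412, 'sq': 0.801, 'rq': 0.498},
+                      'Things': {'pq': 0.455, 'sq': 0.812, 'rq': 0.543},
+                      'Stuff':  {'pq': 0.347, 'sq': 0.784, 'rq': 0.430}}
+        parse_pq_results(pq_results)['PQ']     # ~41.2 (percent scale)
+        parse_pq_results(pq_results)['PQ_th']  # ~45.5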
+ """ + result = dict() + result['PQ'] = 100 * pq_results['All']['pq'] + result['SQ'] = 100 * pq_results['All']['sq'] + result['RQ'] = 100 * pq_results['All']['rq'] + result['PQ_th'] = 100 * pq_results['Things']['pq'] + result['SQ_th'] = 100 * pq_results['Things']['sq'] + result['RQ_th'] = 100 * pq_results['Things']['rq'] + result['PQ_st'] = 100 * pq_results['Stuff']['pq'] + result['SQ_st'] = 100 * pq_results['Stuff']['sq'] + result['RQ_st'] = 100 * pq_results['Stuff']['rq'] + return result + + +def print_panoptic_table( + pq_results: dict, + classwise_results: Optional[dict] = None, + logger: Optional[Union['MMLogger', str]] = None) -> None: + """Print the panoptic evaluation results table. + + Args: + pq_results(dict): The Panoptic Quality results. + classwise_results(dict, optional): The classwise Panoptic Quality. + results. The keys are class names and the values are metrics. + Defaults to None. + logger (:obj:`MMLogger` | str, optional): Logger used for printing + related information during evaluation. Default: None. + """ + + headers = ['', 'PQ', 'SQ', 'RQ', 'categories'] + data = [headers] + for name in ['All', 'Things', 'Stuff']: + numbers = [ + f'{(pq_results[name][k] * 100):0.3f}' for k in ['pq', 'sq', 'rq'] + ] + row = [name] + numbers + [pq_results[name]['n']] + data.append(row) + table = AsciiTable(data) + print_log('Panoptic Evaluation Results:\n' + table.table, logger=logger) + + if classwise_results is not None: + class_metrics = [(name, ) + tuple(f'{(metrics[k] * 100):0.3f}' + for k in ['pq', 'sq', 'rq']) + for name, metrics in classwise_results.items()] + num_columns = min(8, len(class_metrics) * 4) + results_flatten = list(itertools.chain(*class_metrics)) + headers = ['category', 'PQ', 'SQ', 'RQ'] * (num_columns // 4) + results_2d = itertools.zip_longest( + *[results_flatten[i::num_columns] for i in range(num_columns)]) + data = [headers] + data += [result for result in results_2d] + table = AsciiTable(data) + print_log( + 'Classwise Panoptic Evaluation Results:\n' + table.table, + logger=logger) diff --git a/seg/evaluation/metrics/ins_cls_iou_metric.py b/seg/evaluation/metrics/ins_cls_iou_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..8f86d30f8e73985346ae835f02b3b0ff2a23c397 --- /dev/null +++ b/seg/evaluation/metrics/ins_cls_iou_metric.py @@ -0,0 +1,171 @@ +import os + +import mmcv +import torch +from mmengine.dist import broadcast_object_list, collect_results, is_main_process + +from typing import Dict, Optional, Sequence +from mmengine.evaluator import BaseMetric +from mmdet.registry import METRICS +from mmengine.evaluator.metric import _to_cpu +from mmengine.visualization import Visualizer + + +@METRICS.register_module() +class InsClsIoUMetric(BaseMetric): + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + base_classes=None, + novel_classes=None, + with_score=True, + output_failure=False, + ) -> None: + + super().__init__(collect_device=collect_device, prefix=prefix) + self.scores = [] + self.iou_list = [] + + self.base_scores = [] + self.novel_scores = [] + self.base_iou_list = [] + self.novel_iou_list = [] + + self.with_score = with_score + + if base_classes is not None: + assert novel_classes is not None + num_classes = max(max(base_classes) + 1, max(novel_classes) + 1) + self.base_novel_indicator = torch.zeros((num_classes,), dtype=torch.long) + for clss in base_classes: + self.base_novel_indicator[clss] = 1 + for clss in novel_classes: + self.base_novel_indicator[clss] = 2 + else: + 
self.base_novel_indicator = None + + self.output_failure = output_failure + + def get_iou(self, gt_masks, pred_masks): + gt_masks = gt_masks + n, h, w = gt_masks.shape + intersection = (gt_masks & pred_masks).reshape(n, h * w).sum(dim=-1) + union = (gt_masks | pred_masks).reshape(n, h * w).sum(dim=-1) + ious = (intersection / (union + 1.e-8)) + return ious + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + for data_sample in data_samples: + gt_labels = data_sample['gt_instances']['labels'] + if len(gt_labels) == 0: + score = gt_labels.new_zeros((0,), dtype=torch.float) + ious = gt_labels.new_zeros((0,), dtype=torch.float) + else: + if self.with_score: + if self.base_novel_indicator is not None: + assert (self.base_novel_indicator[gt_labels.cpu()] > 0).all() + pred_labels = data_sample['pred_instances']['labels'] + score = (pred_labels == gt_labels).to(dtype=torch.float) * 100 + if 'masks' in data_sample['pred_instances']: + pred_masks = data_sample['pred_instances']['masks'] + if self.output_failure: + for idx, _score in enumerate(score.cpu().numpy().tolist()): + if _score == 0.: + img_path = data_sample['img_path'] + vis = Visualizer() + rgb_img = mmcv.imread(img_path) + rgb_img = mmcv.bgr2rgb(rgb_img) + vis.set_image(rgb_img) + masks = pred_masks[idx] + # colors = [(0, 176, 237)] + colors = [(250, 177, 135)] + vis.draw_binary_masks(masks, alphas=.85, colors=colors) + vis_res = vis.get_image() + if vis_res is None: + continue + img_name = os.path.basename(img_path) + mmcv.imwrite( + mmcv.rgb2bgr(vis_res), os.path.join( + 'failure_lvis', + img_name.split('.')[0] + '_' + str(idx) + '_' + str(int(gt_labels[idx])) + + '_' + str(int(pred_labels[idx])) + '.jpg') + ) + gt_masks = data_sample['gt_instances']['masks'] + gt_masks = gt_masks.to_tensor(dtype=torch.bool, device=pred_masks.device) + ious = self.get_iou(gt_masks, pred_masks) + else: + ious = gt_labels.new_tensor([0.]) + self.iou_list.append(ious.to(device='cpu')) + if self.base_novel_indicator is not None: + self.base_iou_list.append(ious[self.base_novel_indicator[gt_labels.cpu()] == 1].to(device='cpu')) + self.novel_iou_list.append(ious[self.base_novel_indicator[gt_labels.cpu()] == 2].to(device='cpu')) + if self.with_score: + self.scores.append(score.to(device='cpu')) + if self.base_novel_indicator is not None: + self.base_scores.append(score[self.base_novel_indicator[gt_labels.cpu()] == 1].to(device='cpu')) + self.novel_scores.append(score[self.base_novel_indicator[gt_labels.cpu()] == 2].to(device='cpu')) + + def compute_metrics(self, scores, ious, + base_scores, base_ious, + novel_scores, novel_ious) -> Dict[str, float]: + + iou = ious.mean().item() + results = dict() + results['miou'] = iou + if self.base_novel_indicator is not None: + results['base_iou'] = base_ious.mean().item() + + results['novel_iou'] = novel_ious.mean().item() + if self.with_score: + score = scores.mean().item() + results['score'] = score + if base_scores is not None: + results['base_score'] = base_scores.mean().item() + results['novel_score'] = novel_scores.mean().item() + return results + + def evaluate(self, size: int) -> dict: + _ious = collect_results(self.iou_list, size, self.collect_device) + if self.base_novel_indicator is not None: + _base_ious = collect_results(self.base_iou_list, size, self.collect_device) + _novel_ious = collect_results(self.novel_iou_list, size, self.collect_device) + if self.with_score: + _scores = collect_results(self.scores, size, self.collect_device) + if self.base_novel_indicator is not None: + 
_base_scores = collect_results(self.base_scores, size, self.collect_device) + + _novel_scores = collect_results(self.novel_scores, size, self.collect_device) + + if is_main_process(): + if self.base_novel_indicator is not None: + base_ious = torch.cat(_base_ious) + novel_ious = torch.cat(_novel_ious) + else: + base_ious = None + novel_ious = None + if self.with_score: + scores = torch.cat(_scores) + scores = _to_cpu(scores) + if self.base_novel_indicator is not None: + base_scores = torch.cat(_base_scores) + novel_scores = torch.cat(_novel_scores) + else: + base_scores = None + novel_scores = None + else: + scores = None + base_scores = None + novel_scores = None + ious = torch.cat(_ious) + ious = _to_cpu(ious) + _metrics = self.compute_metrics( + scores, ious, + base_scores, base_ious, + novel_scores, novel_ious + ) + metrics = [_metrics] + else: + metrics = [None] # type: ignore + broadcast_object_list(metrics) + return metrics[0] diff --git a/seg/evaluation/metrics/vip_seg_metric.py b/seg/evaluation/metrics/vip_seg_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9b029be7375a6fa492ba6d6e8fb0774e73467f --- /dev/null +++ b/seg/evaluation/metrics/vip_seg_metric.py @@ -0,0 +1,334 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from collections import defaultdict +from typing import Dict, List, Optional, Sequence, Union +import numpy as np +import torch +from mmdet.evaluation.metrics.coco_panoptic_metric import print_panoptic_table, parse_pq_results +from mmengine import print_log, mkdir_or_exist +from mmengine.dist import barrier, broadcast_object_list, is_main_process +from mmdet.registry import METRICS +from mmdet.evaluation.metrics.base_video_metric import BaseVideoMetric, collect_tracking_results +from panopticapi.evaluation import PQStat + +from seg.models.utils import mmpan2hbpan, INSTANCE_OFFSET_HB, mmgt2hbpan +from seg.models.utils import cal_pq, NO_OBJ_ID, IoUObj + + +def parse_pan_map_hb(pan_map: np.ndarray, data_sample: dict, num_classes: int) -> dict: + result = dict() + result['video_id'] = data_sample['video_id'] + result['frame_id'] = data_sample['frame_id'] + + # For video evaluation, each map may include several loads, + # it is not efficient for saving an extra png map, especially + # for machines not with high performance ssd. + pan_labels = np.unique(pan_map) + segments_info = [] + for pan_label in pan_labels: + sem_label = pan_label // INSTANCE_OFFSET_HB + if sem_label >= num_classes: + continue + mask = (pan_map == pan_label).astype(np.uint8) + area = mask.sum() + # _mask = maskUtils.encode(np.asfortranarray(mask)) + # _mask['counts'] = _mask['counts'].decode() + segments_info.append({ + 'id': int(pan_label), + 'category_id': sem_label, + 'area': int(area), + 'mask': mask + }) + result['segments_info'] = segments_info + + return result + + +def parse_data_sample_gt(data_sample: dict, num_things: int, num_stuff: int) -> dict: + num_classes = num_things + num_stuff + result = dict() + result['video_id'] = data_sample['video_id'] + result['frame_id'] = data_sample['frame_id'] + + # For video evaluation, each map may include several loads, + # it is not efficient for saving an extra png map, especially + # for machines not with high performance ssd. 
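+    # ID convention, shared with ``parse_pan_map_hb`` above (INSTANCE_OFFSET_HB
+    # comes from seg.models.utils; its exact value is not shown in this patch):
+    #   pan_id = category_id * INSTANCE_OFFSET_HB + instance_index
+    # so ``pan_id // INSTANCE_OFFSET_HB`` recovers the class, and NO_OBJ_ID
+    # marks the void region appended at the end of ``segments_info``.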
+ gt_instances = data_sample['gt_instances'] + segments_info = [] + for thing_id in range(len(gt_instances['labels'])): + mask = gt_instances['masks'].masks[thing_id].astype(np.uint8) + area = mask.sum() + pan_id = gt_instances['instances_ids'][thing_id] + cat = int(gt_instances['labels'][thing_id]) + if cat >= num_things: + raise ValueError(f"not reasonable value {cat}") + # _mask = maskUtils.encode(np.asfortranarray(mask)) + # _mask['counts'] = _mask['counts'].decode() + segments_info.append({ + 'id': int(pan_id), + 'category_id': cat, + 'area': int(area), + 'mask': mask + }) + + gt_sem_seg = data_sample['gt_sem_seg']['sem_seg'][0].cpu().numpy() + for stuff_id in np.unique(gt_sem_seg): + if stuff_id < num_things: + continue + if stuff_id >= num_classes: + assert stuff_id == NO_OBJ_ID // INSTANCE_OFFSET_HB + _mask = (gt_sem_seg == stuff_id).astype(np.uint8) + area = _mask.sum() + cat = int(stuff_id) + pan_id = cat * INSTANCE_OFFSET_HB + segments_info.append({ + 'id': int(pan_id), + 'category_id': cat, + 'area': int(area), + 'mask': _mask + }) + + if segments_info[-1]['id'] != NO_OBJ_ID: + segments_info.append({ + 'id': int(NO_OBJ_ID), + 'category_id': NO_OBJ_ID // INSTANCE_OFFSET_HB, + 'area': 0, + 'mask': np.zeros_like(gt_sem_seg, dtype=np.uint8) + }) + result['segments_info'] = segments_info + return result + + +@METRICS.register_module() +class VIPSegMetric(BaseVideoMetric): + """mAP evaluation metrics for the VIS task. + + Args: + metric (str | list[str]): Metrics to be evaluated. + Default value is `youtube_vis_ap`.. + outfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonyms metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Default: None + format_only (bool): If True, only formatting the results to the + official format and not performing evaluation. Defaults to False. 
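+
+    Example (illustrative evaluator config; ``VPQ@k`` accumulates PQ over
+    sliding clips of ``k`` consecutive frames, see ``compute_metrics``)::
+
+        val_evaluator = dict(
+            type='VIPSegMetric',
+            metric=['VPQ@1', 'VPQ@2', 'VPQ@4', 'VPQ@6'])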
+ """ + + default_prefix: Optional[str] = 'vip_seg' + + def __init__(self, + metric: Union[str, List[str]] = 'VPQ@1', + outfile_prefix: Optional[str] = None, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + format_only: bool = False) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + # vis evaluation metrics + self.metrics = metric if isinstance(metric, list) else [metric] + self.format_only = format_only + allowed_metrics = ['VPQ'] + for metric in self.metrics: + if metric not in allowed_metrics and metric.split('@')[0] not in allowed_metrics: + raise KeyError( + f"metric should be 'youtube_vis_ap', but got {metric}.") + + self.outfile_prefix = outfile_prefix + self.per_video_res = [] + self.categories = {} + self._vis_meta_info = defaultdict(list) # record video and image infos + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + for track_data_sample in data_samples: + video_data_samples = track_data_sample['video_data_samples'] + ori_video_len = video_data_samples[0].ori_video_length + if ori_video_len == len(video_data_samples): + # video process + self.process_video(video_data_samples) + else: + # image process + raise NotImplementedError + + def process_video(self, data_samples): + video_length = len(data_samples) + + num_things = len(self.dataset_meta['thing_classes']) + num_stuff = len(self.dataset_meta['stuff_classes']) + num_classes = num_things + num_stuff + for frame_id in range(video_length): + img_data_sample = data_samples[frame_id].to_dict() + # 0 is for dummy dimension in fusion head, not batch. + pred = mmpan2hbpan(img_data_sample['pred_track_panoptic_seg']['sem_seg'][0], num_classes=num_classes) + + if self.format_only: + vid_id = data_samples[frame_id].video_id + gt = mmgt2hbpan(data_samples[frame_id]) + mkdir_or_exist('vipseg_output/gt/') + mkdir_or_exist('vipseg_output/pred/') + torch.save(gt.to(device='cpu'), + 'vipseg_output/gt/{:06d}_{:06d}.pth'.format(vid_id, frame_id)) + torch.save(torch.tensor(pred, device='cpu'), + 'vipseg_output/pred/{:06d}_{:06d}.pth'.format(vid_id, frame_id)) + continue + + pred_json = parse_pan_map_hb(pred, img_data_sample, num_classes=num_classes) + gt_json = parse_data_sample_gt(img_data_sample, num_things=num_things, num_stuff=num_stuff) + self.per_video_res.append((pred_json, gt_json)) + + if self.format_only: + return + + video_results = [] + for pred, gt in self.per_video_res: + intersection_info = dict() + gt_no_obj_info = gt['segments_info'][-1] + for pred_seg_info in pred['segments_info']: + intersection = int((gt_no_obj_info['mask'] * pred_seg_info['mask']).sum()) + union = pred_seg_info['area'] + intersection_info[gt_no_obj_info['id'], pred_seg_info['id']] = IoUObj( + intersection=intersection, + union=union + ) + for pred_seg_info in pred['segments_info']: + for gt_seg_info in gt['segments_info'][:-1]: + intersection = int((gt_seg_info['mask'] * pred_seg_info['mask']).sum()) + union = gt_seg_info['area'] + pred_seg_info['area'] - \ + intersection - intersection_info[NO_OBJ_ID, pred_seg_info['id']].intersection + intersection_info[gt_seg_info['id'], pred_seg_info['id']] = IoUObj( + intersection=intersection, + union=union + ) + video_results.append(intersection_info) + self.per_video_res.clear() + self.results.append(video_results) + + def compute_metrics(self, results: List) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (List): The processed results of each batch. 
+ + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. + """ + # split gt and prediction list + eval_results = {} + if self.format_only: + return eval_results + for metric in self.metrics: + seq_len = int(metric.split('@')[-1]) + pq_stat = PQStat() + cnt = 0 + for vid_idx, video_instances in enumerate(results): + for frame_x in range(len(video_instances)): + if frame_x + seq_len > len(video_instances): + break + global_intersection_info = defaultdict(IoUObj) + for frame_offset in range(seq_len): + frame_info = video_instances[frame_x + frame_offset] + for gt_id, pred_id in frame_info: + global_intersection_info[gt_id, pred_id] += frame_info[gt_id, pred_id] + pq_stat += cal_pq(global_intersection_info, classes=self.dataset_meta['classes']) + # global_intersection_info = defaultdict(IoUObj) + # for frame_idx, frame_info in enumerate(video_instances): + # for gt_id, pred_id in frame_info: + # global_intersection_info[gt_id, pred_id] += frame_info[gt_id, pred_id] + # if frame_idx - seq_len >= 0: + # out_frame_info = video_instances[frame_idx - seq_len] + # for gt_id, pred_id in out_frame_info: + # global_intersection_info[gt_id, pred_id] -= out_frame_info[gt_id, pred_id] + # assert global_intersection_info[gt_id, pred_id].is_legal() + # if frame_idx - seq_len >= -1: + # pq_stat += cal_pq(global_intersection_info, classes=self.dataset_meta['classes']) + # cnt += 1 + print_log("Total calculated clips: " + str(cnt), logger='current') + + sub_metrics = [('All', None), ('Things', True), ('Stuff', False)] + pq_results = {} + + for name, isthing in sub_metrics: + pq_results[name], classwise_results = pq_stat.pq_average( + self.categories, isthing=isthing) + if name == 'All': + pq_results['classwise'] = classwise_results + + # classwise_results = { + # k: v + # for k, v in zip(self.dataset_meta['classes'], + # pq_results['classwise'].values()) + # } + + print_panoptic_table(pq_results, None, logger='current') + metric_results = parse_pq_results(pq_results) + for key in metric_results: + eval_results[metric + f'_{key}'] = metric_results[key] + return eval_results + + def evaluate(self, size: int) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + """ + # wait for all processes to complete prediction. + barrier() + + cls_idx = 0 + for thing_cls in self.dataset_meta['thing_classes']: + self.categories[cls_idx] = {'class': thing_cls, 'isthing': 1} + cls_idx += 1 + for stuff_cls in self.dataset_meta['stuff_classes']: + self.categories[cls_idx] = {'class': stuff_cls, 'isthing': 0} + cls_idx += 1 + assert cls_idx == len(self.dataset_meta['classes']) + + if len(self.results) == 0: + warnings.warn( + f'{self.__class__.__name__} got empty `self.results`. 
Please ' + 'ensure that the processed results are properly added into ' + '`self.results` in `process` method.') + + results = collect_tracking_results(self.results, self.collect_device) + + # # gather seq_info + # gathered_seq_info = all_gather_object(self._vis_meta_info['videos']) + # all_seq_info = [] + # for _seq_info in gathered_seq_info: + # all_seq_info.extend(_seq_info) + # # update self._vis_meta_info + # self._vis_meta_info = dict(videos=all_seq_info) + + if is_main_process(): + print_log( + f"There are totally {len(results)} videos to be evaluated.", + logger='current' + ) + _metrics = self.compute_metrics(results) # type: ignore + # Add prefix to metric names + if self.prefix: + _metrics = { + '/'.join((self.prefix, k)): v + for k, v in _metrics.items() + } + metrics = [_metrics] + else: + metrics = [None] # type: ignore + + broadcast_object_list(metrics) + + # reset the results list + self.results.clear() + # reset the vis_meta_info + self._vis_meta_info.clear() + return metrics[0] diff --git a/seg/evaluation/metrics/vos_metric.py b/seg/evaluation/metrics/vos_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..116a3e5744d34c461343a377c8866186af5f2c68 --- /dev/null +++ b/seg/evaluation/metrics/vos_metric.py @@ -0,0 +1,121 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path +from collections import defaultdict +from typing import Dict, List, Optional, Sequence + +import mmcv +import numpy as np +from mmengine import mkdir_or_exist +from mmengine.dist import barrier +from mmdet.registry import METRICS +from mmdet.evaluation.metrics.base_video_metric import BaseVideoMetric + +PALETTE = { + 'davis': b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 
\x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0', + 'mose': b'\x00\x00\x00\xe4\x1a\x1c7~\xb8M\xafJ\x98N\xa3\xff\x7f\x00\xff\xff3\xa6V(\xf7\x81\xbf\x99\x99\x99f\xc2\xa5\xfc\x8db\x8d\xa0\xcb\xe7\x8a\xc3\xa6\xd8T\xff\xd9/\xe5\xc4\x94\xb3\xb3\xb3\x8d\xd3\xc7\xff\xff\xb3\xbe\xba\xda\xfb\x80r\x80\xb1\xd3\xfd\xb4b\xb3\xdei\xfc\xcd\xe5\xd9\xd9\xd9\xbc\x80\xbd\xcc\xeb\xc5\xff\xedo', +} + + +@METRICS.register_module() +class VOSMetric(BaseVideoMetric): + """mAP evaluation metrics for the VIS task. + + Args: + metric (str | list[str]): Metrics to be evaluated. + Default value is `youtube_vis_ap`.. + outfile_prefix (str | None): The prefix of json files. It includes + the file path and the prefix of filename, e.g., "a/b/prefix". + If not specified, a temp file will be created. Defaults to None. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonyms metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Default: None + format_only (bool): If True, only formatting the results to the + official format and not performing evaluation. Defaults to False. + """ + + default_prefix: Optional[str] = 'vip_seg' + + def __init__(self, + collect_device: str = 'cpu', + prefix: Optional[str] = None, + format_only: bool = False, + palette: Optional[str] = None, + results_path: str = 'DAVIS' + ) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + self.format_only = format_only + if palette is not None: + self.palette = PALETTE[palette] + else: + self.palette = None + self.results_path = results_path + + self.per_video_res = [] + self.categories = {} + self._vis_meta_info = defaultdict(list) # record video and image infos + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + for track_data_sample in data_samples: + video_data_samples = track_data_sample['video_data_samples'] + if 'pred_track_proposal' not in video_data_samples[0]: + continue + ori_video_len = video_data_samples[0].ori_video_length + if ori_video_len == len(video_data_samples): + # video process + self.process_video(video_data_samples) + else: + # image process + raise NotImplementedError + + def process_video(self, data_samples): + video_length = len(data_samples) + mkdir_or_exist(self.results_path) + for frame_id in range(video_length): + img_data_sample = data_samples[frame_id].to_dict() + pred = img_data_sample['pred_track_proposal'] + h, w = pred.shape + pred_map = np.zeros((h, w, 3), dtype=np.uint8) + for ins_id in np.unique(pred): + if ins_id == 0: + continue + r = ins_id // 1000000 + g = (ins_id % 1000000) // 1000 + b = ins_id % 1000 + pred_map[pred == ins_id] = np.array([r, g, b], dtype=np.uint8) + ori_img_path = data_samples[frame_id].img_path + folder_name = os.path.basename(os.path.dirname(ori_img_path)) + file_name = os.path.basename(ori_img_path) + file_name = file_name.replace('.jpg', '.png') + if self.palette is not None: + from 
PIL import Image + pred_map = mmcv.bgr2rgb(pred_map) + pil_image = Image.fromarray(pred_map) + pil_image = pil_image.convert('P', palette=self.palette) + out_path = os.path.join(self.results_path, folder_name, file_name) + mkdir_or_exist(os.path.dirname(out_path)) + pil_image.save(out_path) + else: + mmcv.imwrite(pred_map, os.path.join(self.results_path, folder_name, file_name)) + + def compute_metrics(self, results: List) -> Dict[str, float]: + return {} + + def evaluate(self, size: int) -> dict: + """Evaluate the model performance of the whole dataset after processing + all batches. + + Args: + size (int): Length of the entire validation dataset. + + Returns: + dict: Evaluation metrics dict on the val dataset. The keys are the + names of the metrics, and the values are corresponding results. + """ + # wait for all processes to complete prediction. + barrier() + metrics = self.compute_metrics([]) + return metrics diff --git a/seg/models/backbones/__init__.py b/seg/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5ce4055421c879569b8e7c55a3849a6408985 --- /dev/null +++ b/seg/models/backbones/__init__.py @@ -0,0 +1,2 @@ +from .openclip_backbone import OpenCLIPBackbone +from .openclip_backbone import OpenCLIPBackboneText diff --git a/seg/models/backbones/openclip_backbone.py b/seg/models/backbones/openclip_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8e5f711065f3f6afee75456ba531305d2c0145 --- /dev/null +++ b/seg/models/backbones/openclip_backbone.py @@ -0,0 +1,358 @@ +from typing import Optional, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from mmdet.registry import MODELS + +from mmengine.model import BaseModule +from mmengine.dist import get_dist_info +from mmengine.logging import MMLogger +from timm.layers import resample_abs_pos_embed + +import ext.open_clip as open_clip +from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix + + +def flatten_permute(x): + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + return x + + +@MODELS.register_module() +class OpenCLIPBackbone(BaseModule): + """OpenCLIPBackbone, + Please refer to: + https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface + for the supported models and checkpoints. + """ + STAGES = 4 + + def __init__( + self, + img_size: int = 1024, + model_name: str = '', + fix: bool = True, + fix_layers: Optional[List] = None, + init_cfg=None, + ): + assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \ + f"{init_cfg['type']} is not supported." 
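        # init_cfg['type'] selects how the CLIP weights are obtained:
        #   - 'clip_pretrain': load an open_clip checkpoint (image and text towers)
        #     identified by init_cfg['checkpoint'];
        #   - 'image_pretrain': build the visual tower from image-only pretrained
        #     weights (pretrained_image=True);
        #   - 'Pretrained': build the bare architecture, then load a local
        #     checkpoint below with load_checkpoint_with_prefix.
        # With more than one process, rank 0 instantiates the model first so any
        # checkpoint download happens once before the other ranks hit the barrier.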
+ pretrained = init_cfg['checkpoint'] + super().__init__(init_cfg=None) + self.init_cfg = init_cfg + self.logger = MMLogger.get_current_instance() + rank, world_size = get_dist_info() + + if world_size > 1: + if rank == 0: + if init_cfg['type'] == 'clip_pretrain': + _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + + else: + pass + dist.barrier() + + # Get the clip model + if init_cfg['type'] == 'clip_pretrain': + clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, + return_transform=False, logger=self.logger) + elif init_cfg['type'] == 'image_pretrain': + clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger) + elif init_cfg['type'] == 'Pretrained': + clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger) + else: + raise NotImplementedError + + self.out_indices = (0, 1, 2, 3) + model_name_lower = model_name.lower() + if 'convnext_' in model_name_lower: + model_type = 'convnext' + if '_base' in model_name_lower: + output_channels = [128, 256, 512, 1024] + feat_size = 0 + elif '_large' in model_name_lower: + output_channels = [192, 384, 768, 1536] + feat_size = 0 + elif '_xxlarge' in model_name_lower: + output_channels = [384, 768, 1536, 3072] + feat_size = 0 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif 'rn' in model_name_lower: + model_type = 'resnet' + if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']: + output_channels = [256, 512, 1024, 2048] + feat_size = 7 + elif model_name_lower == 'rn50x4': + output_channels = [320, 640, 1280, 2560] + feat_size = 9 + elif model_name_lower == 'rn50x16': + output_channels = [384, 768, 1536, 3072] + feat_size = 12 + elif model_name_lower == 'rn50x64': + output_channels = [512, 1024, 2048, 4096] + feat_size = 14 + else: + raise NotImplementedError(f"{model_name} not supported yet.") + elif "vit" in model_name_lower: + model_type = 'vit' + if model_name_lower == 'vit-l-14': + output_channels = [1024, 1024, 1024, 1024] + feat_size = 0 + assert not clip_model.visual.input_patchnorm + assert clip_model.visual.attn_pool is None + else: + raise NotImplementedError(f"{model_name} not supported yet.") + else: + raise NotImplementedError(f"{model_name} not supported yet.") + + self.model_name = model_name + self.fix = fix + self.model_type = model_type + self.output_channels = output_channels + self.feat_size = feat_size + + # Get the visual model + if self.model_type == 'resnet': + self.stem = nn.Sequential(*[ + clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1, + clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2, + clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3, + ]) + elif self.model_type == 'convnext': + self.stem = clip_model.visual.trunk.stem + elif self.model_type == 'vit': + self.stem = clip_model.visual.conv1 + else: + raise ValueError + + if self.model_type == 'resnet': + self.avgpool = clip_model.visual.avgpool + elif self.model_type == 'convnext': + self.avgpool = nn.Identity() + elif self.model_type == 'vit': + self.avgpool = flatten_permute + else: + raise ValueError + + self.res_layers = [] + if self.model_type in ['vit']: + self.t_class_embedding = clip_model.visual.class_embedding + self.t_positional_embedding = 
clip_model.visual.positional_embedding + self.t_ln_pre_trans = clip_model.visual.ln_pre + self.t_transformer = clip_model.visual.transformer + else: + for i in range(self.STAGES): + if self.model_type == 'resnet': + layer_name = f'layer{i + 1}' + layer = getattr(clip_model.visual, layer_name) + elif self.model_type == 'convnext': + layer_name = f'layer{i + 1}' + layer = clip_model.visual.trunk.stages[i] + else: + raise ValueError + self.add_module(layer_name, layer) + self.res_layers.append(layer_name) + + if self.model_type == 'resnet': + self.norm_pre = nn.Identity() + elif self.model_type == 'convnext': + self.norm_pre = clip_model.visual.trunk.norm_pre + elif self.model_type == 'vit': + self.norm_pre = nn.Identity() + + if self.model_type == 'resnet': + self.head = clip_model.visual.attnpool + elif self.model_type == 'convnext': + self.head = nn.Sequential(*[ + clip_model.visual.trunk.head, + clip_model.visual.head, + ]) + elif self.model_type == 'vit': + self.head = clip_model.visual.ln_post + + if self.init_cfg['type'] == 'Pretrained': + checkpoint_path = pretrained + state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix']) + self.load_state_dict(state_dict, strict=True) + + self.fix_layers = fix_layers + + if not self.fix: + self.train() + for name, param in self.norm_pre.named_parameters(): + param.requires_grad = False + for name, param in self.head.named_parameters(): + param.requires_grad = False + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + for name, param in res_layer.named_parameters(): + param.requires_grad = False + if i == 0: + for name, param in self.stem.named_parameters(): + param.requires_grad = False + + if self.fix: + self.train(mode=False) + for name, param in self.named_parameters(): + param.requires_grad = False + + def init_weights(self): + self.logger.info(f"Init Config for {self.model_name}") + self.logger.info(self.init_cfg) + + def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module: + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + if self.fix: + super().train(mode=False) + else: + super().train(mode=mode) + if self.fix_layers is not None: + for i, layer_name in enumerate(self.res_layers): + if i in self.fix_layers: + res_layer = getattr(self, layer_name) + res_layer.train(mode=False) + if i == 0: + self.stem.train(mode=False) + return self + + def forward_func(self, x): + x = self.stem(x) + h, w = x.shape[-2:] + x = self.avgpool(x) + outs = [] + if self.model_type == 'vit': + x = torch.cat( + [self.t_class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1 + ) # shape = [*, grid ** 2 + 1, width] + new_pos_embed = resample_abs_pos_embed( + self.t_positional_embedding[None], + [h, w], + num_prefix_tokens=1 + ) + x = x + new_pos_embed.to(x.dtype) + x = self.t_ln_pre_trans(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.t_transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = x[:, 1:] + x = x.permute(0, 2, 1).unflatten(2, (h, w)) # BCHW + for i in range(self.STAGES): + outs.append( + F.interpolate( + x, + scale_factor=2 ** (2 - i), + mode='bilinear', + align_corners=False + ) + ) + else: + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x).contiguous() + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def 
get_clip_feature(self, backbone_feat): + if self.model_type == 'resnet': + return backbone_feat + elif self.model_type == 'convnext': + return self.norm_pre(backbone_feat) + raise NotImplementedError + + def forward_feat(self, features): + if self.model_type == 'convnext': + batch, num_query, channel = features.shape + features = features.reshape(batch * num_query, channel, 1, 1) + features = self.head(features) + return features.view(batch, num_query, features.shape[-1]) + elif self.model_type == 'resnet': + num_query, channel, seven, seven = features.shape + features = self.head(features) + return features + + def forward(self, x): + if self.fix: + with torch.no_grad(): + outs = self.forward_func(x) + else: + outs = self.forward_func(x) + return outs + + def get_text_model(self): + return OpenCLIPBackboneText( + self.model_name, + init_cfg=self.init_cfg + ) + + +@MODELS.register_module() +class OpenCLIPBackboneText(BaseModule): + def __init__( + self, + model_name: str = '', + init_cfg=None, + ): + assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported." + pretrained = init_cfg['checkpoint'] + super().__init__(init_cfg=None) + self.init_cfg = init_cfg + self.logger = MMLogger.get_current_instance() + rank, world_size = get_dist_info() + + if world_size > 1: + if rank == 0: + _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False, + logger=self.logger) + else: + pass + dist.barrier() + + # Get the clip model + clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False, + logger=self.logger) + + # Get the textual model + self.text_tokenizer = open_clip.get_tokenizer(model_name) + self.text_transformer = clip_model.transformer + self.text_token_embedding = clip_model.token_embedding + self.text_pe = clip_model.positional_embedding + self.text_ln_final = clip_model.ln_final + self.text_proj = clip_model.text_projection + + self.register_buffer('text_attn_mask', clip_model.attn_mask) + + self.param_dtype = torch.float32 + self.model_name = model_name + + def init_weights(self): + self.logger.info(f"Init Config for {self.model_name}") + self.logger.info(self.init_cfg) + + # Copied from + # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343 + @torch.no_grad() + def forward(self, text): + text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device) + x = self.text_token_embedding(text_tokens).to(self.param_dtype) + x = x + self.text_pe.to(self.param_dtype) + x = x.permute(1, 0, 2) + x = self.text_transformer(x, attn_mask=self.text_attn_mask) + x = x.permute(1, 0, 2) + x = self.text_ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj + return x diff --git a/seg/models/data_preprocessor/__init__.py b/seg/models/data_preprocessor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d717cc380a10fe6fe0ebd5ec9061c33781f0ee6e --- /dev/null +++ b/seg/models/data_preprocessor/__init__.py @@ -0,0 +1,2 @@ +from .vidseg_data_preprocessor import VideoSegDataPreprocessor +from .ovsam_preprocessor import OVSAMDataPreprocessor diff --git a/seg/models/data_preprocessor/ovsam_preprocessor.py b/seg/models/data_preprocessor/ovsam_preprocessor.py new file mode 100644 index 
0000000000000000000000000000000000000000..f44149c41dc290fd2a7db5a4f3336a8544864121 --- /dev/null +++ b/seg/models/data_preprocessor/ovsam_preprocessor.py @@ -0,0 +1,405 @@ +import copy +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from mmdet.models import DetDataPreprocessor +from mmdet.registry import MODELS +from kornia.contrib import distance_transform +from mmengine.structures import InstanceData + +from seg.models.data_preprocessor import VideoSegDataPreprocessor + + +def get_center_coords(gt_instances, rescale_shape=None, device='cpu'): + if rescale_shape is not None: + masks = gt_instances.masks + masks = masks.rescale(rescale_shape) + else: + masks = gt_instances.masks + masks = masks.to_tensor(dtype=torch.bool, device=device)[:, None] + point_coords = [] + for mask in masks: + mask = mask[None] + n, _, h, w = mask.shape + mask_dt = ( + distance_transform( + (~F.pad(mask, pad=(1, 1, 1, 1), mode='constant', value=0)).float() + )[:, :, 1:-1, 1:-1] + ) + selected_point = torch.tensor([mask_dt.argmax() / w, mask_dt.argmax() % w]).long().flip(0).to( + device) + point_coords.append(selected_point) + if len(point_coords) > 0: + point_coords = torch.stack(point_coords)[:, None] + else: + point_coords = torch.empty((0, 1, 2), dtype=torch.int32).to(device=device) + return point_coords + + +def get_random_points(gt_instances, device='cpu'): + point_coords = [] + for instance_idx in range(len(gt_instances)): + mask = gt_instances.masks.masks[instance_idx] + candidate_indices = torch.tensor(mask, device=device).nonzero() + assert len(candidate_indices) > 0 + selected_point = candidate_indices[torch.randperm( + len(candidate_indices), dtype=torch.int32, device=device)[0]].flip(0) + point_coords.append(selected_point) + if len(point_coords) > 0: + point_coords = torch.stack(point_coords)[:, None] + else: + point_coords = torch.empty((0, 1, 2), dtype=torch.int32).to(device=device) + return point_coords + + +@MODELS.register_module() +class OVSAMDataPreprocessor(DetDataPreprocessor): + def __init__(self, *args, + use_det: bool = False, + use_point: bool = False, + use_center_point: bool = False, + use_point_det: bool = False, + use_center_point_det: bool = False, + use_point_pseudo_box: bool = False, + use_img_center: bool = False, + use_custom_bbox: Optional[Tuple] = None, + use_custom_point: Optional[Tuple] = None, + num_proposals: int = 60, + default_mode: str = 'sam', + **kwargs): + super().__init__(*args, **kwargs) + self.num_proposals = num_proposals + self.use_det = use_det + self.use_point = use_point + self.use_center_point = use_center_point + self.use_point_det = use_point_det + self.use_center_point_det = use_center_point_det + self.use_point_pseudo_box = use_point_pseudo_box + self.use_img_center = use_img_center + self.use_custom_bbox = use_custom_bbox + self.use_custom_point = use_custom_point + self.default_mode = default_mode + + def forward(self, data: dict, training: bool = False) -> dict: + data = super().forward(data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] + if 'data_tag' in data_samples[0]: + data_tag = data_samples[0].data_tag + for i in range(1, len(data_samples)): + assert data_samples[i].data_tag == data_tag + else: + data_tag = self.default_mode + for i in range(0, len(data_samples)): + data_samples[i].data_tag = data_tag + device = inputs.device + + if data_tag == 'sam_mul': + for data_sample in data_samples: + gt_instances_collected = data_sample.gt_instances_collected + gt_instances = 
data_sample.gt_instances + masks_list = [] + for idx in range(len(gt_instances_collected)): + gt_ids = gt_instances_collected.sub_instances[idx] + masks_list.append(gt_instances.masks[gt_ids]) + gt_instances = InstanceData( + labels=torch.zeros_like(gt_instances_collected.idx), + masks=masks_list, + point_coords=gt_instances_collected.point_coords, + bp=torch.zeros_like(gt_instances_collected.idx), # all box + ) + # all points + data_sample.gt_instances = gt_instances + del data_sample.gt_instances_collected + elif data_tag == 'sam': + num_proposals = self.num_proposals if training else 10000000 + if self.use_custom_bbox: + for data_sample in data_samples: + img_shape = data_sample.img_shape + data_sample.gt_instances = InstanceData( + bboxes=inputs.new_tensor([[img_shape[1] * self.use_custom_bbox[0], + img_shape[0] * self.use_custom_bbox[1], + img_shape[1] * self.use_custom_bbox[2], + img_shape[0] * self.use_custom_bbox[3]]]) + ) + elif self.use_img_center: + for data_sample in data_samples: + data_sample.gt_instances = InstanceData( + point_coords=inputs.new_tensor([[[data_sample.img_shape[1] / 2, data_sample.img_shape[0] / 2]]]) + ) + elif self.use_custom_point: + for data_sample in data_samples: + data_sample.gt_instances = InstanceData( + point_coords=inputs.new_tensor([[[self.use_custom_point[0], self.use_custom_point[1]]]]) + ) + elif self.use_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if not training: + bboxes = gt_instances.bboxes + scale_factor = bboxes.new_tensor(data_sample.scale_factor).repeat(2) + bboxes = bboxes * scale_factor + gt_instances.bboxes = bboxes + num_ins = len(gt_instances) + bp_indicator = torch.zeros((num_ins,)) + gt_instances.bp = bp_indicator.to(device=device) + data_sample.gt_instances = gt_instances + elif self.use_point_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + if len(gt_instances) < num_proposals: + num_copy = num_proposals // len(gt_instances) + 1 + gt_instances = InstanceData.cat([copy.deepcopy(gt_instances) for _ in range(num_copy)]) + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_random_points(gt_instances, device=device) + else: + raise NotImplementedError + num_ins = len(gt_instances) + bp_indicator = torch.arange(2).repeat_interleave((num_ins // 2) + 1)[:num_ins] + gt_instances = gt_instances[torch.randperm(num_ins, device=device)] + gt_instances.bp = bp_indicator.to(device=device) + data_sample.gt_instances = gt_instances + elif self.use_center_point_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_center_coords(gt_instances, device=device) + else: + gt_instances.point_coords = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + bboxes = gt_instances.bboxes + scale_factor = bboxes.new_tensor(data_sample.scale_factor).repeat(2) + bboxes = bboxes * scale_factor + gt_instances.bboxes = bboxes + data_sample.gt_instances = gt_instances + elif self.use_point: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_random_points(gt_instances, device=device) + else: + raise NotImplementedError + data_sample.gt_instances = gt_instances + elif self.use_center_point: + for data_sample in data_samples: + 
gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_center_coords(gt_instances, device=device) + else: + gt_instances.point_coords = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + data_sample.gt_instances = gt_instances + elif self.use_point_pseudo_box: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + if training: + if len(gt_instances) < num_proposals: + num_copy = num_proposals // len(gt_instances) + 1 + gt_instances = InstanceData.cat([copy.deepcopy(gt_instances) for _ in range(num_copy)]) + gt_instances = gt_instances[:num_proposals] + points = get_random_points(gt_instances, device=device) + else: + points = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + points = points.squeeze(1) + gt_instances.point_coords = torch.cat([points - 3, points + 3], 1) + gt_instances.bp = torch.zeros_like(gt_instances.labels) # bug to match sam_mul + data_sample.gt_instances = gt_instances + else: + raise NotImplementedError + elif data_tag == 'coco': + pass + elif data_tag == 'img': + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + h, w = data_sample.img_shape + gt_instances.bboxes = torch.tensor( + [[0., 0., h, w]], dtype=torch.float32, device=gt_instances.labels.device + ) + gt_instances.bp = torch.zeros((1,), dtype=torch.int32, device=gt_instances.labels.device) + elif data_tag == 'mosaic_img': + b, three, h, w = inputs.shape + num_img_per_batch = 4 * 4 + assert b % num_img_per_batch == 0 + target_h, target_w = h * 4, w * 4 + new_b = b // num_img_per_batch + result_input = inputs.new_empty(b // num_img_per_batch, three, target_h, target_w) + cnt = 0 + result_data_samples = [] + for id_b in range(new_b): + cur_data_sample = data_samples[cnt] + cur_gt_instances = [] + for id_x in range(4): + for id_y in range(4): + result_input[id_b, :, id_x * h: (id_x + 1) * h, id_y * w: (id_y + 1) * w] = inputs[cnt] + img_gt_instances = data_samples[cnt].gt_instances + img_gt_instances.bboxes += img_gt_instances.bboxes.new_tensor([ + id_x * h, id_y * w, id_x * h, id_y * w + ]) + cur_gt_instances.append(img_gt_instances) + cnt += 1 + cur_gt_instances = InstanceData.cat(cur_gt_instances) + cur_data_sample.gt_instances = cur_gt_instances + result_data_samples.append(cur_data_sample) + + inputs = result_input + data_samples = result_data_samples + else: + raise NotImplementedError + return dict(inputs=inputs, data_samples=data_samples) + + +@MODELS.register_module() +class OVSAMVideoSegDataPreprocessor(VideoSegDataPreprocessor): + def __init__(self, *args, + use_det: bool = False, + use_point: bool = False, + use_center_point: bool = False, + use_point_det: bool = False, + use_center_point_det: bool = False, + use_point_pseudo_box: bool = False, + num_proposals: int = 60, + **kwargs): + super().__init__(*args, **kwargs) + self.num_proposals = num_proposals + self.use_det = use_det + self.use_point = use_point + self.use_center_point = use_center_point + self.use_point_det = use_point_det + self.use_center_point_det = use_center_point_det + self.use_point_pseudo_box = use_point_pseudo_box + + def forward(self, data: dict, training: bool = False) -> dict: + data = super().forward(data, training=training) + inputs, data_samples = data['inputs'], data['data_samples'] + if 'data_tag' in data_samples[0]: + data_tag = data_samples[0].data_tag + for i in range(1, len(data_samples)): + assert 
data_samples[i].data_tag == data_tag + else: + data_tag = 'sam' + for i in range(0, len(data_samples)): + data_samples[i].data_tag = data_tag + device = inputs.device + + if data_tag == 'sam_mul': + for data_sample in data_samples: + gt_instances_collected = data_sample.gt_instances_collected + gt_instances = data_sample.gt_instances + masks_list = [] + for idx in range(len(gt_instances_collected)): + gt_ids = gt_instances_collected.sub_instances[idx] + masks_list.append(gt_instances.masks[gt_ids]) + gt_instances = InstanceData( + labels=torch.zeros_like(gt_instances_collected.idx), + masks=masks_list, + point_coords=gt_instances_collected.point_coords, + bp=torch.zeros_like(gt_instances_collected.idx), # all box + ) + # all points + data_sample.gt_instances = gt_instances + del data_sample.gt_instances_collected + elif data_tag == 'sam': + num_proposals = self.num_proposals if training else 10000000 + if self.use_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if not training: + bboxes = gt_instances.bboxes + scale_factor = bboxes.new_tensor(data_sample.scale_factor).repeat(2) + bboxes = bboxes * scale_factor + gt_instances.bboxes = bboxes + data_sample.gt_instances = gt_instances + elif self.use_point_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + if len(gt_instances) < num_proposals: + num_copy = num_proposals // len(gt_instances) + 1 + gt_instances = InstanceData.cat([copy.deepcopy(gt_instances) for _ in range(num_copy)]) + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_random_points(gt_instances, device=device) + else: + raise NotImplementedError + num_ins = len(gt_instances) + bp_indicator = torch.arange(2).repeat_interleave((num_ins // 2) + 1)[:num_ins] + gt_instances = gt_instances[torch.randperm(num_ins, device=device)] + gt_instances.bp = bp_indicator.to(device=device) + data_sample.gt_instances = gt_instances + elif self.use_center_point_det: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_center_coords(gt_instances, device=device) + else: + gt_instances.point_coords = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + bboxes = gt_instances.bboxes + scale_factor = bboxes.new_tensor(data_sample.scale_factor).repeat(2) + bboxes = bboxes * scale_factor + gt_instances.bboxes = bboxes + data_sample.gt_instances = gt_instances + elif self.use_point: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_random_points(gt_instances, device=device) + else: + raise NotImplementedError + data_sample.gt_instances = gt_instances + elif self.use_center_point: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + gt_instances = gt_instances[:num_proposals] + if training: + gt_instances.point_coords = get_center_coords(gt_instances, device=device) + else: + gt_instances.point_coords = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + data_sample.gt_instances = gt_instances + elif self.use_point_pseudo_box: + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + if training: + if len(gt_instances) < num_proposals: + num_copy = num_proposals // len(gt_instances) + 1 + 
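                        # Too few instances for the requested number of prompts:
                        # replicate the ground-truth instances (deep copies) until at
                        # least num_proposals are available, then truncate; each kept
                        # instance later gets its own randomly sampled prompt point,
                        # which is turned into a small pseudo box a few lines below.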
gt_instances = InstanceData.cat([copy.deepcopy(gt_instances) for _ in range(num_copy)]) + gt_instances = gt_instances[:num_proposals] + points = get_random_points(gt_instances, device=device) + else: + points = get_center_coords( + gt_instances, rescale_shape=data_sample.img_shape, device=device + ) + points = points.squeeze(1) + gt_instances.point_coords = torch.cat([points - 3, points + 3], 1) + gt_instances.bp = torch.zeros_like(gt_instances.labels) # bug to match sam_mul + data_sample.gt_instances = gt_instances + else: + raise NotImplementedError + elif data_tag == 'coco': + pass + elif data_tag == 'img': + for data_sample in data_samples: + gt_instances = data_sample.gt_instances + h, w = data_sample.img_shape + gt_instances.bboxes = torch.tensor( + [[0., 0., h, w]], dtype=torch.float32, device=gt_instances.labels.device + ) + gt_instances.bp = torch.zeros((1,), dtype=torch.int32, device=gt_instances.labels.device) + else: + raise NotImplementedError + return dict(inputs=inputs, data_samples=data_samples) diff --git a/seg/models/data_preprocessor/vidseg_data_preprocessor.py b/seg/models/data_preprocessor/vidseg_data_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..d05d9b363a2ca37c2f02e20ea3f5280556ef2551 --- /dev/null +++ b/seg/models/data_preprocessor/vidseg_data_preprocessor.py @@ -0,0 +1,323 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from mmengine.model import BaseDataPreprocessor +from torch import Tensor +import torch.nn.functional as F +from mmdet.structures.bbox import BaseBoxes +from mmengine.model.utils import stack_batch + +from mmdet.models.utils.misc import samplelist_boxtype2tensor, unfold_wo_center +from mmdet.registry import MODELS +from mmdet.structures import TrackDataSample, TrackSampleList +from mmdet.structures.mask import BitmapMasks +from mmdet.models.data_preprocessors import DetDataPreprocessor +from mmengine.structures import PixelData + +try: + import skimage +except ImportError: + skimage = None + + +@MODELS.register_module() +class VideoSegDataPreprocessor(DetDataPreprocessor): + """Image pre-processor for tracking tasks. + + Accepts the data sampled by the dataloader, and preprocesses + it into the format of the model input. ``TrackDataPreprocessor`` + provides the tracking data pre-processing as follows: + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to inputs. + - Convert inputs from bgr to rgb if the shape of input is (1, 3, H, W). + - Normalize image with defined std and mean. + - Do batch augmentations during training. + - Record the information of ``batch_input_shape`` and ``pad_shape``. + + Args: + mean (Sequence[Number], optional): The pixel mean of R, G, B + channels. Defaults to None. + std (Sequence[Number], optional): The pixel standard deviation of + R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (Number): The padded pixel value. Defaults to 0. + pad_mask (bool): Whether to pad instance masks. Defaults to False. + mask_pad_value (int): The padded pixel value for instance masks. + Defaults to 0. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. 
+ rgb_to_bgr (bool): whether to convert image from RGB to BGR. + Defaults to False. + use_det_processor (bool): whether to use DetDataPreprocessor + in the training phase. This is mainly for tracking models + that are fed a single image rather than a group of images during training. + Defaults to False. + boxtype2tensor (bool): Whether to convert the ``BaseBoxes`` type of + bboxes data to ``Tensor`` type. Defaults to True. + batch_augments (list[dict], optional): Batch-level augmentations. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + use_det_processor: bool = False, + **kwargs): + super().__init__(mean=mean, std=std, **kwargs) + self.use_det_processor = use_det_processor + if mean is not None and not self.use_det_processor: + # overwrite the ``register_buffer`` in ``ImgDataPreprocessor`` + # since the shape of ``mean`` and ``std`` in tracking tasks must be + # (T, C, H, W), where T is the temporal length of the video. + self.register_buffer('mean', + torch.tensor(mean).view(1, -1, 1, 1), False) + self.register_buffer('std', + torch.tensor(std).view(1, -1, 1, 1), False) + + + def forward(self, data: dict, training: bool = False) -> Dict: + """Perform normalization, padding and bgr2rgb conversion based on + ``TrackDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + Tuple[Dict[str, List[torch.Tensor]], OptSampleList]: Data in the + same format as the model input. + """ + if not isinstance(data['data_samples'][0], TrackDataSample): + use_det = True + else: + use_det = False + if self.use_det_processor and training or use_det: + batch_pad_shape = self._get_pad_shape(data) + else: + batch_pad_shape = self._get_track_pad_shape(data) + + data = self.cast_data(data) + imgs, data_samples = data['inputs'], data['data_samples'] + + if self.use_det_processor and training or use_det: + assert imgs[0].dim() == 3, \ + 'Only support the 3 dims when use detpreprocessor in training' + if self._channel_conversion: + imgs = [_img[[2, 1, 0], ...] for _img in imgs] + # Convert to `float` + imgs = [_img.float() for _img in imgs] + if self._enable_normalize: + imgs = [(_img - self.mean.squeeze(0)) / self.std.squeeze(0) for _img in imgs] + inputs = stack_batch(imgs, self.pad_size_divisor, self.pad_value) + else: + assert imgs[0].dim() == 4, \ + 'Only support the 4 dims when use trackprocessor in training' + # The shape of imgs[0] is (T, C, H, W). + channel = imgs[0].size(1) + if self._channel_conversion and channel == 3: + imgs = [_img[:, [2, 1, 0], ...] for _img in imgs] + # change to `float` + imgs = [_img.float() for _img in imgs] + if self._enable_normalize: + imgs = [(_img - self.mean) / self.std for _img in imgs] + inputs = stack_track_batch(imgs, self.pad_size_divisor, + self.pad_value) + + if data_samples is not None: + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head.
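            # Record the batch-level padded shape and the per-sample pad_shape in
            # each sample's metainfo; for track samples this is done per frame, and
            # ground-truth masks / semantic maps are padded to the same shape when
            # training.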
+ batch_input_shape = tuple(inputs.size()[-2:]) + if self.use_det_processor and training or use_det: + for data_sample, pad_shape in zip(data_samples, + batch_pad_shape): + data_sample.set_metainfo({ + 'batch_input_shape': batch_input_shape, + 'pad_shape': pad_shape + }) + if self.boxtype2tensor: + samplelist_boxtype2tensor(data_samples) + + if self.pad_mask and training: + self.pad_gt_masks(data_samples) + + if self.pad_seg and training: + self.pad_gt_sem_seg(data_samples) + else: + for track_data_sample, pad_shapes in zip( + data_samples, batch_pad_shape): + for i in range(len(track_data_sample)): + det_data_sample = track_data_sample[i] + det_data_sample.set_metainfo({ + 'batch_input_shape': batch_input_shape, + 'pad_shape': pad_shapes[i] + }) + + if self.boxtype2tensor: + tracking_samplelist_boxtype2tensor(data_samples) + + if self.pad_mask and training: + self.pad_track_gt_masks(data_samples) + + if self.pad_seg and training: + self.pad_track_gt_sem_seg(data_samples) + + if training and self.batch_augments is not None: + for batch_aug in self.batch_augments: + if self.use_det_processor and training or use_det: + inputs, data_samples = batch_aug(inputs, data_samples) + else: + # For video segmentation, the batch augmentation are conducted + # on the batch dimension only, which means it will be run several + # times given the number of frames. + final_inputs = [] + for frame_id in range(inputs.size(1)): + det_data_samples = [ + track_data_sample[frame_id] + for track_data_sample in data_samples + ] + aug_inputs, aug_det_samples = batch_aug( + inputs[:, frame_id], det_data_samples) + final_inputs.append(aug_inputs.unsqueeze(1)) + for track_data_sample, det_sample in zip( + data_samples, aug_det_samples): + track_data_sample.video_data_samples[frame_id] = det_sample + inputs = torch.cat(final_inputs, dim=1) + + # Note: inputs may contain large number of frames, so we must make + # sure that the mmeory is contiguous for stable forward + inputs = inputs.contiguous() + return dict(inputs=inputs, data_samples=data_samples) + + def _get_track_pad_shape(self, data: dict) -> Dict[str, List]: + """Get the pad_shape of each image based on data and pad_size_divisor. + + Args: + data (dict): Data sampled from dataloader. + + Returns: + Dict[str, List]: The shape of padding. 
+ """ + batch_pad_shape = dict() + batch_pad_shape = [] + for imgs in data['inputs']: + # The sequence images in one sample among a batch have the same + # original shape + pad_h = int(np.ceil(imgs.shape[-2] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int(np.ceil(imgs.shape[-1] / + self.pad_size_divisor)) * self.pad_size_divisor + pad_shapes = [(pad_h, pad_w)] * imgs.size(0) + batch_pad_shape.append(pad_shapes) + return batch_pad_shape + + def pad_track_gt_masks(self, + data_samples: Sequence[TrackDataSample]) -> None: + """Pad gt_masks to shape of batch_input_shape.""" + if 'masks' in data_samples[0][0].get('gt_instances', None): + for track_data_sample in data_samples: + for i in range(len(track_data_sample)): + det_data_sample = track_data_sample[i] + masks = det_data_sample.gt_instances.masks + # TODO: whether to use BitmapMasks + assert isinstance(masks, BitmapMasks) + batch_input_shape = det_data_sample.batch_input_shape + det_data_sample.gt_instances.masks = masks.pad( + batch_input_shape, pad_val=self.mask_pad_value) + + def pad_track_gt_sem_seg(self, + data_samples: Sequence[TrackDataSample]) -> None: + """Pad gt_sem_seg to shape of batch_input_shape.""" + if 'gt_sem_seg' in data_samples[0][0]: + for track_data_sample in data_samples: + for i in range(len(track_data_sample)): + det_data_sample = track_data_sample[i] + gt_sem_seg = det_data_sample.gt_sem_seg.sem_seg + h, w = gt_sem_seg.shape[-2:] + pad_h, pad_w = det_data_sample.batch_input_shape + gt_sem_seg = F.pad( + gt_sem_seg, + pad=(0, max(pad_w - w, 0), 0, max(pad_h - h, 0)), + mode='constant', + value=self.seg_pad_value) + det_data_sample.gt_sem_seg = PixelData(sem_seg=gt_sem_seg) + + +def stack_track_batch(tensors: List[torch.Tensor], + pad_size_divisor: int = 0, + pad_value: Union[int, float] = 0) -> torch.Tensor: + """Stack multiple tensors to form a batch and pad the images to the max + shape use the right bottom padding mode in these images. If + ``pad_size_divisor > 0``, add padding to ensure the common height and width + is divisible by ``pad_size_divisor``. The difference between this function + and ``stack_batch`` in MMEngine is that this function can process batch + sequence images with shape (N, T, C, H, W). + + Args: + tensors (List[Tensor]): The input multiple tensors. each is a + TCHW 4D-tensor. T denotes the number of key/reference frames. + pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding + to ensure the common height and width is divisible by + ``pad_size_divisor``. This depends on the model, and many + models need a divisibility of 32. Defaults to 0 + pad_value (int, float): The padding value. Defaults to 0 + + Returns: + Tensor: The NTCHW 5D-tensor. N denotes the batch size. 
+ """ + assert isinstance(tensors, list), \ + f'Expected input type to be list, but got {type(tensors)}' + assert len(set([tensor.ndim for tensor in tensors])) == 1, \ + f'Expected the dimensions of all tensors must be the same, ' \ + f'but got {[tensor.ndim for tensor in tensors]}' + assert tensors[0].ndim == 4, f'Expected tensor dimension to be 4, ' \ + f'but got {tensors[0].ndim}' + assert len(set([tensor.shape[0] for tensor in tensors])) == 1, \ + f'Expected the channels of all tensors must be the same, ' \ + f'but got {[tensor.shape[0] for tensor in tensors]}' + + tensor_sizes = [(tensor.shape[-2], tensor.shape[-1]) for tensor in tensors] + max_size = np.stack(tensor_sizes).max(0) + + if pad_size_divisor > 1: + # the last two dims are H,W, both subject to divisibility requirement + max_size = ( + max_size + + (pad_size_divisor - 1)) // pad_size_divisor * pad_size_divisor + + padded_samples = [] + for tensor in tensors: + padding_size = [ + 0, max_size[-1] - tensor.shape[-1], 0, + max_size[-2] - tensor.shape[-2] + ] + if sum(padding_size) == 0: + padded_samples.append(tensor) + else: + padded_samples.append(F.pad(tensor, padding_size, value=pad_value)) + + return torch.stack(padded_samples, dim=0) + + +def tracking_samplelist_boxtype2tensor(batch_track_samples: TrackSampleList) -> None: + for track_data_sample in batch_track_samples: + for data_samples in track_data_sample.video_data_samples: + if 'gt_instances' in data_samples: + bboxes = data_samples.gt_instances.get('bboxes', None) + if isinstance(bboxes, BaseBoxes): + data_samples.gt_instances.bboxes = bboxes.tensor + if 'pred_instances' in data_samples: + bboxes = data_samples.pred_instances.get('bboxes', None) + if isinstance(bboxes, BaseBoxes): + data_samples.pred_instances.bboxes = bboxes.tensor + if 'ignored_instances' in data_samples: + bboxes = data_samples.ignored_instances.get('bboxes', None) + if isinstance(bboxes, BaseBoxes): + data_samples.ignored_instances.bboxes = bboxes.tensor diff --git a/seg/models/detectors/__init__.py b/seg/models/detectors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9aad4dfce834cf306017910bfdca1ec2cf3d313 --- /dev/null +++ b/seg/models/detectors/__init__.py @@ -0,0 +1,2 @@ +from .mask2former_vid import Mask2formerVideo +from .mask2former_vid_minvis import Mask2formerVideoMinVIS diff --git a/seg/models/detectors/mask2former_vid.py b/seg/models/detectors/mask2former_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..cce03e030571e180cccfc0dfd08d6035867b0984 --- /dev/null +++ b/seg/models/detectors/mask2former_vid.py @@ -0,0 +1,318 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor +import torch.nn.functional as F + +from mmdet.registry import MODELS +from mmdet.structures import SampleList, OptSampleList, TrackDataSample +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmdet.models.detectors.single_stage import SingleStageDetector + +from seg.models.utils import mask_pool + + +@MODELS.register_module() +class Mask2formerVideo(SingleStageDetector): + r"""Implementation of `Per-Pixel Classification is + NOT All You Need for Semantic Segmentation + `_.""" + OVERLAPPING = None + + def __init__(self, + backbone: ConfigType, + neck: OptConfigType = None, + panoptic_head: OptConfigType = None, + panoptic_fusion_head: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + inference_sam: bool = False, + init_cfg: OptMultiConfig = None + ): + super(SingleStageDetector, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + + panoptic_head_ = panoptic_head.deepcopy() + panoptic_head_.update(train_cfg=train_cfg) + panoptic_head_.update(test_cfg=test_cfg) + self.panoptic_head = MODELS.build(panoptic_head_) + + panoptic_fusion_head_ = panoptic_fusion_head.deepcopy() + panoptic_fusion_head_.update(test_cfg=test_cfg) + self.panoptic_fusion_head = MODELS.build(panoptic_fusion_head_) + + self.num_things_classes = self.panoptic_head.num_things_classes + self.num_stuff_classes = self.panoptic_head.num_stuff_classes + self.num_classes = self.panoptic_head.num_classes + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + self.alpha = 0.4 + self.beta = 0.8 + + self.inference_sam = inference_sam + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Dict[str, Tensor]: + """ + Args: + batch_inputs (Tensor): Input images of shape (N, C, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." + + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + x = self.extract_feat(x) + else: + x = self.extract_feat(batch_inputs) + losses = self.panoptic_head.loss(x, batch_data_samples) + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. 
+ + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + num_frames = 0 + bs = batch_inputs.shape[0] + feats = self.extract_feat(batch_inputs) + + # in case no queries are provided for prompt. + if self.inference_sam and len(batch_data_samples[0].gt_instances) == 0: + for idx, data_sample in enumerate(batch_data_samples): + results = InstanceData() + data_sample.pred_instances = results + return batch_data_samples + + mask_cls_results, mask_pred_results, iou_results = self.panoptic_head.predict(feats, batch_data_samples) + + if self.OVERLAPPING is not None: + assert len(self.OVERLAPPING) == self.num_classes + mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results) + + if self.inference_sam: + for idx, data_sample in enumerate(batch_data_samples): + results = InstanceData() + mask = mask_pred_results[idx] + img_height, img_width = data_sample.metainfo['img_shape'][:2] + mask = mask[:, :img_height, :img_width] + ori_height, ori_width = data_sample.metainfo['ori_shape'][:2] + mask = F.interpolate( + mask[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + results.masks = mask.sigmoid() > 0.5 + data_sample.pred_instances = results + return batch_data_samples + + if num_frames > 0: + for frame_id in range(num_frames): + results_list_img = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results[:, :, frame_id], + [batch_data_samples[idx][frame_id] for idx in range(bs)], + rescale=rescale + ) + _ = self.add_track_pred_to_datasample( + [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img + ) + results = batch_data_samples + else: + results_list = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results, + batch_data_samples, + iou_results=iou_results, + rescale=rescale + ) + results = self.add_pred_to_datasample(batch_data_samples, results_list) + + return results + + def add_pred_to_datasample(self, data_samples: SampleList, + results_list: List[dict]) -> SampleList: + """Add predictions to `DetDataSample`. + + Args: + data_samples (list[:obj:`DetDataSample`], optional): A batch of + data samples that contain annotations and predictions. + results_list (List[dict]): Instance segmentation, segmantic + segmentation and panoptic segmentation results. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
+ + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + data_sample.pred_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + data_sample.pred_instances = pred_results['ins_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def add_track_pred_to_datasample(self, data_samples: SampleList, results_list: List[dict]) -> SampleList: + for data_sample, pred_results in zip(data_samples, results_list): + if 'pan_results' in pred_results: + assert self.num_stuff_classes > 0 + pred_results['pan_results'].sem_seg = pred_results['pan_results'].sem_seg.cpu() + data_sample.pred_track_panoptic_seg = pred_results['pan_results'] + + if 'ins_results' in pred_results: + bboxes = pred_results['ins_results']['bboxes'] + labels = pred_results['ins_results']['labels'] + track_ids = torch.arange(len(bboxes), dtype=labels.dtype, device=bboxes.device) + 1 + pred_results['ins_results']['instances_id'] = track_ids + data_sample.pred_track_instances = pred_results['ins_results'] + + if 'pro_results' in pred_results: + data_sample.pred_track_proposal = pred_results['pro_results'] + + assert 'sem_results' not in pred_results + + return data_samples + + def _forward( + self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + """Network forward process. Usually includes backbone, neck and head + forward without any post-processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + tuple[List[Tensor]]: A tuple of features from ``panoptic_head`` + forward. + """ + if isinstance(batch_data_samples[0], TrackDataSample): + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." 
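            # Video clips are flattened to (bs * num_frames, 3, h, w) so the
            # backbone (and neck, if any) sees an ordinary image batch; the
            # panoptic head's outputs are later re-split into per-frame results
            # (see `predict` above).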
+ + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + else: + feats = self.extract_feat(batch_inputs) + results = self.panoptic_head.forward(feats, batch_data_samples) + return results + + def open_voc_inference(self, feats, mask_cls_results, mask_pred_results): + if len(mask_pred_results.shape) == 5: + batch_size = mask_cls_results.shape[0] + num_frames = mask_pred_results.shape[2] + mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1) + else: + batch_size = mask_cls_results.shape[0] + num_frames = 0 + clip_feat = self.backbone.get_clip_feature(feats[-1]) + clip_feat_mask = F.interpolate( + mask_pred_results, + size=clip_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + if num_frames > 0: + clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + instance_feat = mask_pool(clip_feat, clip_feat_mask) + instance_feat = self.backbone.forward_feat(instance_feat) + clip_logit = self.panoptic_head.forward_logit(instance_feat) + clip_logit = clip_logit[..., :-1] + query_logit = mask_cls_results[..., :-1] + + clip_logit = clip_logit.softmax(-1) + query_logit = query_logit.softmax(-1) + overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device) + + valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to( + torch.float32)[..., None] + alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking + beta = torch.ones_like(clip_logit) * self.beta * valid_masking + + cls_logits_seen = ( + (query_logit ** (1 - alpha) * clip_logit ** alpha).log() + * overlapping_mask + ) + cls_logits_unseen = ( + (query_logit ** (1 - beta) * clip_logit ** beta).log() + * (1 - overlapping_mask) + ) + cls_results = cls_logits_seen + cls_logits_unseen + is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:] + mask_cls_results = torch.cat([ + cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1) + mask_cls_results = torch.log(mask_cls_results + 1e-8) + return mask_cls_results diff --git a/seg/models/detectors/mask2former_vid_minvis.py b/seg/models/detectors/mask2former_vid_minvis.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cf9d76ee12e33116b31549a6f361d672fbcd0a --- /dev/null +++ b/seg/models/detectors/mask2former_vid_minvis.py @@ -0,0 +1,299 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
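+#
+# Overview of this module: MinVIS-style clip-wise inference on top of
+# Mask2formerVideo. Long videos are cut into overlapping "tubes" by ``video_split``;
+# each tube is predicted independently by the panoptic head, and
+# ``match_from_embeds`` stitches the per-tube queries together via a Hungarian
+# assignment (``scipy.optimize.linear_sum_assignment``) on 1 - cosine similarity
+# between query embeddings of consecutive tubes. Class/IoU logits are averaged over
+# tubes while mask logits are concatenated along the frame axis.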
+import os + +import torch +from scipy.optimize import linear_sum_assignment +from torch import Tensor +import torch.nn.functional as F + +from mmdet.registry import MODELS +from mmdet.structures import SampleList, TrackDataSample + +from seg.models.detectors import Mask2formerVideo +from seg.models.utils import mask_pool + +BACKBONE_BATCH = 50 + + +def video_split(total, tube_size, overlap=0): + assert tube_size > overlap + total -= overlap + tube_size -= overlap + + if total % tube_size == 0: + splits = total // tube_size + else: + splits = (total // tube_size) + 1 + + ind_list = [] + for i in range(splits): + ind_list.append((i + 1) * tube_size) + + diff = ind_list[-1] - total + + # currently only supports diff < splits + if diff < splits: + for i in range(diff): + ind_list[splits - 1 - i] -= diff - i + else: + ind_list[splits - 1] -= diff + assert ind_list[splits - 1] > 0 + print("Warning: {} / {}".format(total, tube_size)) + + for idx in range(len(ind_list)): + ind_list[idx] += overlap + + return ind_list + + +def match_from_embeds(tgt_embds, cur_embds): + cur_embds = cur_embds / cur_embds.norm(dim=-1, keepdim=True) + tgt_embds = tgt_embds / tgt_embds.norm(dim=-1, keepdim=True) + cos_sim = torch.bmm(cur_embds, tgt_embds.transpose(1, 2)) + + cost_embd = 1 - cos_sim + + C = 1.0 * cost_embd + C = C.cpu() + + indices = [] + for i in range(len(cur_embds)): + indice = linear_sum_assignment(C[i].transpose(0, 1)) # target x current + indice = indice[1] # permutation that makes current aligns to target + indices.append(indice) + + return indices + + +@MODELS.register_module() +class Mask2formerVideoMinVIS(Mask2formerVideo): + r"""Implementation of `Per-Pixel Classification is + NOT All You Need for Semantic Segmentation + `_.""" + OVERLAPPING = None + + def __init__(self, + *args, + clip_size=6, + clip_size_small=3, + whole_clip_thr=0, + small_clip_thr=12, + overlap=0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.clip_size = clip_size + self.clip_size_small = clip_size_small + self.overlap = overlap + self.whole_clip_thr = whole_clip_thr + self.small_clip_thr = small_clip_thr + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs with shape (N, C, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the + input images. Each DetDataSample usually contain + 'pred_instances' and `pred_panoptic_seg`. And the + ``pred_instances`` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + + And the ``pred_panoptic_seg`` contains the following key + + - sem_seg (Tensor): panoptic segmentation mask, has a + shape (1, h, w). + """ + assert isinstance(batch_data_samples[0], TrackDataSample) + + bs, num_frames, three, h, w = batch_inputs.shape + assert three == 3, "Only supporting images with 3 channels." 
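+        # Clips no longer than ``whole_clip_thr`` frames fall back to the parent class
+        # and are predicted in a single pass; everything else is split into overlapping
+        # tubes. Hand-checked example for ``video_split`` above:
+        # video_split(20, 6, overlap=1) -> [6, 11, 16, 20], i.e. the tubes cover frames
+        # [0:6], [5:11], [10:16], [15:20] once the 1-frame overlap is stepped back.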
+ if num_frames <= self.whole_clip_thr: + return super().predict(batch_inputs, batch_data_samples, rescale) + + device = batch_inputs.device + + if num_frames > self.small_clip_thr: + tube_inds = video_split(num_frames, self.clip_size, self.overlap) + else: + tube_inds = video_split(num_frames, self.clip_size_small, self.overlap) + if num_frames > BACKBONE_BATCH: + feat_bins = [[], [], [], []] + num_clip = num_frames // BACKBONE_BATCH + 1 + step_size = num_frames // num_clip + 1 + for i in range(num_clip): + start = i * step_size + end = min(num_frames, (i + 1) * step_size) + inputs = batch_inputs[:, start:end].reshape( + (bs * (end - start), three, h, w)) + _feats = self.extract_feat(inputs) + assert len(_feats) == 4 + for idx, item in enumerate(_feats): + feat_bins[idx].append(item.to('cpu')) + feats = [] + for item in feat_bins: + feat = torch.cat(item, dim=0) + assert feat.size(0) == bs * num_frames, "{} vs {}".format(feat.size(0), bs * num_frames) + feats.append(feat) + else: + x = batch_inputs.reshape((bs * num_frames, three, h, w)) + feats = self.extract_feat(x) + assert len(feats[0]) == bs * num_frames + + del batch_inputs + + ind_pre = 0 + cls_list = [] + mask_list = [] + query_list = [] + iou_list = [] + flag = False + for ind in tube_inds: + tube_feats = [itm[ind_pre:ind].to(device=device) for itm in feats] + tube_data_samples = [TrackDataSample(video_data_samples=itm[ind_pre:ind]) for itm in batch_data_samples] + _mask_cls_results, _mask_pred_results, _query_feat, _iou_results = \ + self.panoptic_head.predict(tube_feats, tube_data_samples, return_query=True) + cls_list.append(_mask_cls_results) + if not flag: + mask_list.append(_mask_pred_results.cpu()) + flag = True + else: + mask_list.append(_mask_pred_results[:, self.overlap:].cpu()) + query_list.append(_query_feat.cpu()) + iou_list.append(_iou_results) + + ind_pre = ind + ind_pre -= self.overlap + + num_tubes = len(tube_inds) + + out_cls = [cls_list[0]] + out_mask = [mask_list[0]] + out_embed = [query_list[0]] + ious = [iou_list[0]] + + for i in range(1, num_tubes): + indices = match_from_embeds(out_embed[-1], query_list[i]) + indices = indices[0] # since bs == 1 + + out_cls.append(cls_list[i][:, indices]) + out_mask.append(mask_list[i][:, indices]) + out_embed.append(query_list[i][:, indices]) + ious.append(iou_list[i][:, indices]) + + del mask_list + del out_embed + mask_cls_results = sum(out_cls) / num_tubes + mask_pred_results = torch.cat(out_mask, dim=2) + iou_results = sum(ious) / num_tubes + + if self.OVERLAPPING is not None: + assert len(self.OVERLAPPING) == self.num_classes + mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results) + + del feats + mask_cls_results = mask_cls_results.to(device='cpu') + iou_results = iou_results.to(device='cpu') + + id_assigner = [{} for _ in range(bs)] + + for frame_id in range(num_frames): + results_list_img = self.panoptic_fusion_head.predict( + mask_cls_results, + mask_pred_results[:, :, frame_id], + [batch_data_samples[idx][frame_id] for idx in range(bs)], + iou_results=iou_results, + rescale=rescale + ) + if frame_id == 0 and 'pro_results' in results_list_img[0]: + for batch_id in range(bs): + mask = results_list_img[batch_id]['pro_results'].to(dtype=torch.int32) + mask_gt = torch.tensor(batch_data_samples[batch_id][frame_id].gt_instances.masks.masks, dtype=torch.int32) + a, b = mask.flatten(1), mask_gt.flatten(1) + intersection = torch.einsum('nc,mc->nm', a, b) + union = (a[:, None] + b[None]).clamp(min=0, max=1).sum(-1) + iou_cost = intersection 
/ union + a_indices, b_indices = linear_sum_assignment(-iou_cost.numpy()) + + for a_ind, b_ind in zip(a_indices, b_indices): + id_assigner[batch_id][a_ind] = batch_data_samples[batch_id][frame_id].gt_instances.instances_ids[b_ind].item() + + if 'pro_results' in results_list_img[0]: + h, w = results_list_img[batch_id]['pro_results'].shape[-2:] + seg_map = torch.full((h, w), 0, dtype=torch.int32, device='cpu') + for ind in id_assigner[batch_id]: + seg_map[results_list_img[batch_id]['pro_results'][ind]] = id_assigner[batch_id][ind] + results_list_img[batch_id]['pro_results'] = seg_map.cpu().numpy() + + _ = self.add_track_pred_to_datasample( + [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img + ) + results = batch_data_samples + + return results + + def open_voc_inference(self, feats, mask_cls_results, mask_pred_results): + if len(mask_pred_results.shape) == 5: + batch_size = mask_cls_results.shape[0] + num_frames = mask_pred_results.shape[2] + mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1) + else: + batch_size = mask_cls_results.shape[0] + num_frames = 0 + clip_feat = self.backbone.get_clip_feature(feats[-1]).to(device=mask_cls_results.device) + clip_feat_mask = F.interpolate( + mask_pred_results, + size=clip_feat.shape[-2:], + mode='bilinear', + align_corners=False + ).to(device=mask_cls_results.device) + if num_frames > 0: + clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3) + instance_feat = mask_pool(clip_feat, clip_feat_mask) + instance_feat = self.backbone.forward_feat(instance_feat) + clip_logit = self.panoptic_head.forward_logit(instance_feat) + clip_logit = clip_logit[..., :-1] + query_logit = mask_cls_results[..., :-1] + + clip_logit = clip_logit.softmax(-1) + query_logit = query_logit.softmax(-1) + overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device) + + valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to( + torch.float32)[..., None] + alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking + beta = torch.ones_like(clip_logit) * self.beta * valid_masking + + cls_logits_seen = ( + (query_logit ** (1 - alpha) * clip_logit ** alpha).log() + * overlapping_mask + ) + cls_logits_unseen = ( + (query_logit ** (1 - beta) * clip_logit ** beta).log() + * (1 - overlapping_mask) + ) + cls_results = cls_logits_seen + cls_logits_unseen + is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:] + mask_cls_results = torch.cat([ + cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1) + mask_cls_results = torch.log(mask_cls_results + 1e-8) + return mask_cls_results diff --git a/seg/models/fusion_head/__init__.py b/seg/models/fusion_head/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67f66084e5696854aa7fff60dd2c6929902f62ec --- /dev/null +++ b/seg/models/fusion_head/__init__.py @@ -0,0 +1 @@ +from .omgseg_fusionhead import OMGFusionHead diff --git a/seg/models/fusion_head/omgseg_fusionhead.py b/seg/models/fusion_head/omgseg_fusionhead.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b6924671d20788c76e28aae6d07073660d8d2c --- /dev/null +++ b/seg/models/fusion_head/omgseg_fusionhead.py @@ -0,0 +1,297 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
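+#
+# Overview of this module: OMGFusionHead turns the per-query class/mask logits of the
+# Mask2Former-style head into final results. Panoptic maps encode
+# ``segment_id = class_id + instance_id * INSTANCE_OFFSET`` (INSTANCE_OFFSET is 1000
+# in mmdet, so e.g. segment_id 2017 decodes to class 17, instance 2); instance
+# results keep the top-scoring "thing" queries with boxes derived from their masks;
+# proposal results keep the top-``num_proposals`` binary masks ranked by the IoU head
+# score.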
+from typing import List + +import torch +import torch.nn.functional as F +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmdet.evaluation.functional import INSTANCE_OFFSET +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.mask import mask2bbox +from mmdet.utils import OptConfigType, OptMultiConfig +from mmdet.models.seg_heads.panoptic_fusion_heads.base_panoptic_fusion_head import BasePanopticFusionHead + + +@MODELS.register_module() +class OMGFusionHead(BasePanopticFusionHead): + + def __init__( + self, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + test_cfg: OptConfigType = None, + loss_panoptic: OptConfigType = None, + init_cfg: OptMultiConfig = None, + **kwargs + ): + super().__init__( + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + test_cfg=test_cfg, + loss_panoptic=loss_panoptic, + init_cfg=init_cfg, + **kwargs) + + def loss(self, **kwargs): + """MaskFormerFusionHead has no training loss.""" + return dict() + + def panoptic_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> PixelData: + """Panoptic segmengation inference. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`PixelData`: Panoptic segment result of shape \ + (h, w), each element in Tensor means: \ + ``segment_id = _cls + instance_id * INSTANCE_OFFSET``. + """ + object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8) + iou_thr = self.test_cfg.get('iou_thr', 0.8) + filter_low_score = self.test_cfg.get('filter_low_score', False) + + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = mask_pred.sigmoid() + + keep = labels.ne(self.num_classes) & (scores > object_mask_thr) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.full((h, w), + self.num_classes, + dtype=torch.int32, + device=cur_masks.device) + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + pass + else: + cur_mask_ids = cur_prob_masks.argmax(0) + instance_id = 1 + for k in range(cur_classes.shape[0]): + pred_class = int(cur_classes[k].item()) + isthing = pred_class < self.num_things_classes + mask = cur_mask_ids == k + mask_area = mask.sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + + if filter_low_score: + mask = mask & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0: + if mask_area / original_area < iou_thr: + continue + + if not isthing: + # different stuff regions of same class will be + # merged here, and stuff share the instance_id 0. + panoptic_seg[mask] = pred_class + else: + panoptic_seg[mask] = ( + pred_class + instance_id * INSTANCE_OFFSET) + instance_id += 1 + + return PixelData(sem_seg=panoptic_seg[None]) + + def semantic_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> PixelData: + """Semantic segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`PixelData`: Semantic segment result. 
+ """ + # TODO add semantic segmentation result + raise NotImplementedError + + def instance_postprocess(self, mask_cls: Tensor, + mask_pred: Tensor) -> InstanceData: + """Instance segmengation postprocess. + + Args: + mask_cls (Tensor): Classfication outputs of shape + (num_queries, cls_out_channels) for a image. + Note `cls_out_channels` should includes + background. + mask_pred (Tensor): Mask outputs of shape + (num_queries, h, w) for a image. + + Returns: + :obj:`InstanceData`: Instance segmentation results. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + max_per_image = self.test_cfg.get('max_per_image', 100) + num_queries = mask_cls.shape[0] + # shape (num_queries, num_class) + scores = F.softmax(mask_cls, dim=-1)[:, :-1] + # shape (num_queries * num_class, ) + labels = torch.arange(self.num_classes, device=mask_cls.device). \ + unsqueeze(0).repeat(num_queries, 1).flatten(0, 1) + scores_per_image, top_indices = scores.flatten(0, 1).topk( + max_per_image, sorted=False) + labels_per_image = labels[top_indices] + + query_indices = top_indices // self.num_classes + mask_pred = mask_pred[query_indices] + + # extract things + is_thing = labels_per_image < self.num_things_classes + scores_per_image = scores_per_image[is_thing] + labels_per_image = labels_per_image[is_thing] + mask_pred = mask_pred[is_thing] + + mask_pred_binary = (mask_pred > 0).float() + mask_scores_per_image = (mask_pred.sigmoid() * + mask_pred_binary).flatten(1).sum(1) / ( + mask_pred_binary.flatten(1).sum(1) + 1e-6) + det_scores = scores_per_image * mask_scores_per_image + mask_pred_binary = mask_pred_binary.bool() + bboxes = mask2bbox(mask_pred_binary) + + results = InstanceData() + results.bboxes = bboxes + results.labels = labels_per_image + results.scores = det_scores + results.masks = mask_pred_binary + return results + + def proposal_postprocess(self, mask_score: Tensor, mask_pred: Tensor) -> InstanceData: + max_per_image = self.test_cfg.get('num_proposals', 10) + h, w = mask_pred.shape[-2:] + # shape (num_queries, num_class) + scores = mask_score.sigmoid().squeeze(-1) + scores_per_image, top_indices = scores.topk(max_per_image, sorted=True) + + mask_selected = mask_pred[top_indices] + + proposals = [] + for idx in range(len(mask_selected)): + mask = mask_selected[len(mask_selected) - idx - 1] + proposals.append(mask.sigmoid() > .5) + seg_map = torch.stack(proposals) + return seg_map + + def predict(self, + mask_cls_results: Tensor, + mask_pred_results: Tensor, + batch_data_samples: SampleList, + iou_results=None, + rescale: bool = False, + **kwargs) -> List[dict]: + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + mask_cls_results (Tensor): Mask classification logits, + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should includes background. + mask_pred_results (Tensor): Mask logits, shape + (batch_size, num_queries, h, w). + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + iou_results: None + rescale (bool): If True, return boxes in + original image space. Default False. 
+ + Returns: + list[dict]: Instance segmentation \ + results and panoptic segmentation results for each \ + image. + + .. code-block:: none + + [ + { + 'pan_results': PixelData, + 'ins_results': InstanceData, + # semantic segmentation results are not supported yet + 'sem_results': PixelData + }, + ... + ] + """ + batch_img_metas = [ + data_sample.metainfo for data_sample in batch_data_samples + ] + panoptic_on = self.test_cfg.get('panoptic_on', True) + semantic_on = self.test_cfg.get('semantic_on', False) + instance_on = self.test_cfg.get('instance_on', False) + proposal_on = self.test_cfg.get('proposal_on', False) + assert not semantic_on, 'segmantic segmentation ' \ + 'results are not supported yet.' + + results = [] + idx = 0 + for mask_cls_result, mask_pred_result, meta in zip( + mask_cls_results, mask_pred_results, batch_img_metas): + # remove padding + img_height, img_width = meta['img_shape'][:2] + mask_pred_result = mask_pred_result.to(mask_cls_results.device) + mask_pred_result = mask_pred_result[:, :img_height, :img_width] + + if rescale: + # return result in original resolution + ori_height, ori_width = meta['ori_shape'][:2] + mask_pred_result = F.interpolate( + mask_pred_result[:, None], + size=(ori_height, ori_width), + mode='bilinear', + align_corners=False)[:, 0] + + result = dict() + if panoptic_on: + pan_results = self.panoptic_postprocess( + mask_cls_result, mask_pred_result + ) + result['pan_results'] = pan_results + + if instance_on: + ins_results = self.instance_postprocess( + mask_cls_result, mask_pred_result + ) + result['ins_results'] = ins_results + + if semantic_on: + sem_results = self.semantic_postprocess( + mask_cls_result, mask_pred_result + ) + result['sem_results'] = sem_results + + if proposal_on: + pro_results = self.proposal_postprocess( + iou_results[idx], mask_pred_result + ) + result['pro_results'] = pro_results + + results.append(result) + idx += 1 + + return results diff --git a/seg/models/heads/__init__.py b/seg/models/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a3eaccf753c4c55045c54562d17b55fc103267 --- /dev/null +++ b/seg/models/heads/__init__.py @@ -0,0 +1 @@ +from .mask2former_vid import Mask2FormerVideoHead diff --git a/seg/models/heads/mask2former_vid.py b/seg/models/heads/mask2former_vid.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5d5726418263d97b8be00e2384226b659c969e --- /dev/null +++ b/seg/models/heads/mask2former_vid.py @@ -0,0 +1,1098 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
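+#
+# Overview of this module: a video-capable Mask2Former head. Compared with the stock
+# mmdet head it adds (1) an optional open-vocabulary classifier (``sphere_cls``) that
+# scores mask-pooled, L2-normalised query embeddings against a frozen class-embedding
+# buffer with a fixed logit scale, (2) an optional per-query IoU/objectness channel
+# appended to the class logits, (3) ``SinePositionalEncoding3D`` so a clip is handled
+# as one flattened spatio-temporal feature map, and (4) optional box/point prompt
+# queries (``enable_box_query`` / ``prepare_for_dn_mo``).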
+import copy +import os +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from mmcv.cnn import Conv2d +from mmcv.ops import point_sample +from mmdet.models import Mask2FormerTransformerDecoder, inverse_sigmoid, coordinate_to_encoding +from mmdet.structures.bbox import bbox_xyxy_to_cxcywh +from mmengine import print_log +from mmengine.dist import get_dist_info +from mmengine.model import caffe2_xavier_init, ModuleList +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList, TrackDataSample +from mmdet.utils import (ConfigType, InstanceList, OptConfigType, + OptMultiConfig, reduce_mean) +from mmdet.models.layers import SinePositionalEncoding3D +from mmdet.models.utils import multi_apply, preprocess_panoptic_gt, get_uncertain_point_coords_with_randomness +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead + +from seg.models.utils import preprocess_video_panoptic_gt, mask_pool +from seg.models.utils.load_checkpoint import load_checkpoint_with_prefix + + +@MODELS.register_module() +class Mask2FormerVideoHead(AnchorFreeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`ConfigDict` or dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer decoder position encoding. Defaults to + dict(num_feats=128, normalize=True). + loss_cls (:obj:`ConfigDict` or dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`ConfigDict` or dict, optional): Training config of + Mask2Former head. + test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of + Mask2Former head. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ + dict], optional): Initialization config dict. Defaults to None. 
+ """ + + def __init__(self, + in_channels: List[int], + feat_channels: int, + out_channels: int, + num_things_classes: int = 80, + num_stuff_classes: int = 53, + num_queries: int = 100, + num_transformer_feat_level: int = 3, + pixel_decoder: ConfigType = ..., + enforce_decoder_input_project: bool = False, + transformer_decoder: ConfigType = ..., + positional_encoding: ConfigType = None, + loss_cls: ConfigType = None, + loss_mask: ConfigType = None, + loss_dice: ConfigType = None, + loss_iou: ConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + init_cfg: OptMultiConfig = None, + # ov configs + sphere_cls: bool = False, + ov_classifier_name: Optional[str] = None, + logit: Optional[int] = None, + # box sup + matching_whole_map: bool = False, + # box query + enable_box_query: bool = False, + group_assigner: OptConfigType = None, + **kwargs) -> None: + super(AnchorFreeHead, self).__init__(init_cfg=init_cfg) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.layer_cfg. \ + self_attn_cfg.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = MODELS.build(pixel_decoder_) + self.transformer_decoder = Mask2FormerTransformerDecoder( + **transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = SinePositionalEncoding3D( + **positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + if not sphere_cls: + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + if loss_iou is not None: + self.iou_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, 1)) + else: + self.iou_embed = None + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = TASK_UTILS.build(self.train_cfg['assigner']) + self.sampler = TASK_UTILS.build( + self.train_cfg['sampler'], default_args=dict(context=self)) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + 
+ self.class_weight = loss_cls.class_weight + self.loss_cls = MODELS.build(loss_cls) + self.loss_mask = MODELS.build(loss_mask) + self.loss_dice = MODELS.build(loss_dice) + if loss_iou is not None: + self.loss_iou = MODELS.build(loss_iou) + else: + self.loss_iou = None + + # prepare OV things + # OV cls embed + if sphere_cls: + rank, world_size = get_dist_info() + if ov_classifier_name is None: + _dim = 1024 # temporally hard code + cls_embed = torch.empty(self.num_classes, _dim) + torch.nn.init.orthogonal_(cls_embed) + cls_embed = cls_embed[:, None] + else: + # ov_path = os.path.join(os.path.expanduser('~/.cache/embd'), f"{ov_classifier_name}.pth") + ov_path = os.path.join(os.path.expanduser('./models/'), f"{ov_classifier_name}.pth") + cls_embed = torch.load(ov_path) + cls_embed_norm = cls_embed.norm(p=2, dim=-1) + assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm)) + if self.loss_cls and self.loss_cls.use_sigmoid: + pass + else: + _dim = cls_embed.size(2) + _prototypes = cls_embed.size(1) + + if rank == 0: + back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda') + # back_token = back_token / back_token.norm(p=2, dim=-1, keepdim=True) + else: + back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda') + if world_size > 1: + dist.broadcast(back_token, src=0) + back_token = back_token.to(device='cpu') + cls_embed = torch.cat([ + cls_embed, back_token.repeat(_prototypes, 1)[None] + ], dim=0) + self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False) + + # cls embd proj + cls_embed_dim = self.cls_embed.size(0) + self.cls_proj = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, cls_embed_dim) + ) + + # Haobo Yuan: + # For the logit_scale, I refer to this issue. + # https://github.com/openai/CLIP/issues/46#issuecomment-945062212 + # https://github.com/openai/CLIP/issues/46#issuecomment-782558799 + # Based on my understanding, it is a mistake of CLIP. + # Because they mention that they refer to InstDisc (Wu, 2018) paper. + # InstDisc set a non-learnable temperature to np.log(1 / 0.07). 
+ # 4.6052 is np.log(1 / 0.01) + # np.log(1 / 0.07) will be fast converged to np.log(1 / 0.01) + if logit is None: + logit_scale = torch.tensor(4.6052, dtype=torch.float32) + else: + logit_scale = torch.tensor(logit, dtype=torch.float32) + self.register_buffer('logit_scale', logit_scale, persistent=False) + + # Mask Pooling + self.mask_pooling = mask_pool + self.mask_pooling_proj = nn.Sequential( + nn.LayerNorm(feat_channels), + nn.Linear(feat_channels, feat_channels) + ) + + # box inst + self.matching_whole_map = matching_whole_map + + # enable box query + self.enable_box_query = enable_box_query + if self.enable_box_query: + self.num_mask_tokens = 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, feat_channels) + self.pb_embedding = nn.Embedding(2, feat_channels) + self.pos_linear = nn.Linear(2 * feat_channels, feat_channels) + + def init_weights(self) -> None: + if self.init_cfg['type'] == 'Pretrained': + checkpoint_path = self.init_cfg['checkpoint'] + state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix']) + msg = self.load_state_dict(state_dict, strict=False) + print_log(f"m: {msg[0]} \n u: {msg[1]}", logger='current') + return None + + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def preprocess_gt( + self, batch_gt_instances: InstanceList, + batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList: + """Preprocess the ground truth for all images. + + Args: + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``labels``, each is + ground truth labels of each bbox, with shape (num_gts, ) + and ``masks``, each is ground truth masks of each instances + of a image, shape (num_gts, h, w). + batch_gt_semantic_segs (list[Optional[PixelData]]): Ground truth of + semantic segmentation, each with the shape (1, h, w). + [0, num_thing_class - 1] means things, + [num_thing_class, num_class-1] means stuff, + 255 means VOID. It's None when training instance segmentation. + + Returns: + list[obj:`InstanceData`]: each contains the following keys + + - labels (Tensor): Ground truth class indices\ + for a image, with shape (n, ), n is the sum of\ + number of stuff type and number of instance in a image. + - masks (Tensor): Ground truth mask for a\ + image, with shape (n, h, w). 
+ """ + num_things_list = [self.num_things_classes] * len(batch_gt_instances) + num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances) + if isinstance(batch_gt_instances[0], List): + gt_labels_list = [ + [torch.stack([torch.ones_like(gt_instances['labels']) * frame_id, gt_instances['labels']], dim=1) + for frame_id, gt_instances in enumerate(gt_vid_instances)] + for gt_vid_instances in batch_gt_instances + ] + gt_labels_list = [torch.cat(gt_labels, dim=0) for gt_labels in gt_labels_list] + gt_masks_list = [ + [gt_instances['masks'] for gt_instances in gt_vid_instances] + for gt_vid_instances in batch_gt_instances + ] + gt_semantic_segs = [ + [None if gt_semantic_seg is None else gt_semantic_seg.sem_seg + for gt_semantic_seg in gt_vid_semantic_segs] + for gt_vid_semantic_segs in batch_gt_semantic_segs + ] + if gt_semantic_segs[0][0] is None: + gt_semantic_segs = [None] * len(batch_gt_instances) + else: + gt_semantic_segs = [torch.stack(gt_sem_seg, dim=0) for gt_sem_seg in gt_semantic_segs] + gt_instance_ids_list = [ + [torch.stack([torch.ones_like(gt_instances['instances_ids']) * frame_id, gt_instances['instances_ids']], + dim=1) + for frame_id, gt_instances in enumerate(gt_vid_instances)] + for gt_vid_instances in batch_gt_instances + ] + gt_instance_ids_list = [torch.cat(gt_instance_ids, dim=0) for gt_instance_ids in gt_instance_ids_list] + targets = multi_apply(preprocess_video_panoptic_gt, gt_labels_list, + gt_masks_list, gt_semantic_segs, gt_instance_ids_list, + num_things_list, num_stuff_list) + else: + gt_labels_list = [ + gt_instances['labels'] for gt_instances in batch_gt_instances + ] + gt_masks_list = [ + gt_instances['masks'] for gt_instances in batch_gt_instances + ] + gt_semantic_segs = [ + None if gt_semantic_seg is None else gt_semantic_seg.sem_seg + for gt_semantic_seg in batch_gt_semantic_segs + ] + targets = multi_apply(preprocess_panoptic_gt, gt_labels_list, + gt_masks_list, gt_semantic_segs, num_things_list, + num_stuff_list) + labels, masks = targets + batch_gt_instances = [ + InstanceData(labels=label, masks=mask) + for label, mask in zip(labels, masks) + ] + return batch_gt_instances + + def get_queries(self, batch_data_samples): + img_size = batch_data_samples[0].batch_input_shape + query_feat_list = [] + bp_list = [] + for idx, data_sample in enumerate(batch_data_samples): + is_box = data_sample.gt_instances.bp.eq(0) + is_point = data_sample.gt_instances.bp.eq(1) + assert is_box.any() + sparse_embed, _ = self.pe( + data_sample.gt_instances[is_box], + image_size=img_size, + with_bboxes=True, + with_points=False, + ) + sparse_embed = [sparse_embed] + if is_point.any(): + _sparse_embed, _ = self.pe( + data_sample.gt_instances[is_point], + image_size=img_size, + with_bboxes=False, + with_points=True, + ) + sparse_embed.append(_sparse_embed) + sparse_embed = torch.cat(sparse_embed) + assert len(sparse_embed) == len(data_sample.gt_instances) + + query_feat_list.append(self.query_proj(sparse_embed.flatten(1, 2))) + bp_list.append(data_sample.gt_instances.bp) + + query_feat = torch.stack(query_feat_list) + bp_labels = torch.stack(bp_list).to(dtype=torch.long) + bp_embed = self.bp_embedding.weight[bp_labels] + bp_embed = bp_embed.repeat_interleave(self.num_mask_tokens, dim=1) + + query_feat = query_feat + bp_embed + return query_feat, None + + def get_targets( + self, + cls_scores_list: List[Tensor], + mask_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + return_sampling_results: bool = False + ) -> 
Tuple[List[Union[Tensor, int]]]: + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape (num_queries, + cls_out_channels). + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape (num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + return_sampling_results (bool): Whether to return the sampling + results. Defaults to False. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images.\ + Each with shape (num_queries, ). + - label_weights_list (list[Tensor]): Label weights\ + of all images. Each with shape (num_queries, ). + - mask_targets_list (list[Tensor]): Mask targets of\ + all images. Each with shape (num_queries, h, w). + - mask_weights_list (list[Tensor]): Mask weights of\ + all images. Each with shape (num_queries, ). + - avg_factor (int): Average factor that is used to average\ + the loss. When using sampling method, avg_factor is + usually the sum of positive and negative priors. When + using `MaskPseudoSampler`, `avg_factor` is usually equal + to the number of positive priors. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end. + """ + results = multi_apply( + self._get_targets_single, cls_scores_list, mask_preds_list, batch_gt_instances, batch_img_metas + ) + labels_list, label_weights_list, mask_targets_list, mask_weights_list, \ + pos_inds_list, neg_inds_list, sampling_results_list = results[:7] + rest_results = list(results[7:]) + + avg_factor = sum([results.avg_factor for results in sampling_results_list]) + res = (labels_list, label_weights_list, mask_targets_list, mask_weights_list, avg_factor) + + if return_sampling_results: + res = res + sampling_results_list + + return res + tuple(rest_results) + + def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> Tuple[Tensor]: + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_instances (:obj:`InstanceData`): It contains ``labels`` and + ``masks``. + img_meta (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + - sampling_result (:obj:`SamplingResult`): Sampling results. 
+ """ + gt_labels = gt_instances.labels + gt_masks = gt_instances.masks + + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + if not self.matching_whole_map: + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), + point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), + point_coords.repeat(num_gts, 1, 1)).squeeze(1) + else: + mask_points_pred = mask_pred + gt_points_masks = gt_masks + + sampled_gt_instances = InstanceData( + labels=gt_labels, masks=gt_points_masks) + sampled_pred_instances = InstanceData( + scores=cls_score, masks=mask_points_pred) + # assign and sample + assign_result = self.assigner.assign( + pred_instances=sampled_pred_instances, + gt_instances=sampled_gt_instances, + img_meta=img_meta + ) + pred_instances = InstanceData(scores=cls_score, masks=mask_pred) + sampling_result = self.sampler.sample( + assign_result=assign_result, + pred_instances=pred_instances, + gt_instances=gt_instances) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((num_queries,)) + mask_weights[pos_inds] = 1.0 + + return labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds, sampling_result + + def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Dict[str, Tensor]: + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape (num_decoder, batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape (num_decoder, batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + batch_gt_instances_list = [ + batch_gt_instances for _ in range(num_dec_layers) + ] + img_metas_list = [batch_img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self._loss_by_feat_single, all_cls_scores, all_mask_preds, batch_gt_instances_list, img_metas_list + ) + + loss_dict = dict() + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + return loss_dict + + def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor, + batch_gt_instances: List[InstanceData], + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + batch_gt_instances (list[obj:`InstanceData`]): each contains + ``labels`` and ``masks``. + batch_img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + batch_size, num_ins = cls_scores.size(0), cls_scores.size(1) + # hack here: + is_sam = num_ins != self.num_queries + + if not is_sam: + cls_scores_list = [cls_scores[i] for i in range(batch_size)] + mask_preds_list = [mask_preds[i] for i in range(batch_size)] + labels_list, label_weights_list, mask_targets_list, mask_weights_list, avg_factor = \ + self.get_targets(cls_scores_list, mask_preds_list, batch_gt_instances, batch_img_metas) + labels = torch.stack(labels_list, dim=0) + label_weights = torch.stack(label_weights_list, dim=0) + mask_targets = torch.cat(mask_targets_list, dim=0) + mask_weights = torch.stack(mask_weights_list, dim=0) + else: + labels = torch.stack([item.labels for item in batch_gt_instances]) + label_weights = labels.new_ones((batch_size, num_ins), dtype=torch.float) + mask_targets = torch.cat([item.masks for item in batch_gt_instances]) + mask_weights = mask_targets.new_ones((batch_size, num_ins), dtype=torch.float) + avg_factor = cls_scores.size(1) + + # classification loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + class_weight = cls_scores.new_tensor(self.class_weight) + ignore_inds = labels.eq(-1.) + # zero will not be involved in the loss cal + labels[ignore_inds] = 0 + label_weights[ignore_inds] = 0. 
+ obj_inds = labels.eq(self.num_classes) + if is_sam: + cls_avg_factor = cls_scores.new_tensor([0]) + else: + cls_avg_factor = class_weight[labels].sum() + cls_avg_factor = reduce_mean(cls_avg_factor) + cls_avg_factor = max(cls_avg_factor, 1) + if self.loss_iou is not None: + loss_cls = self.loss_cls( + cls_scores[..., :-1], + labels, + label_weights, + avg_factor=cls_avg_factor + ) + loss_iou = self.loss_iou( + cls_scores[..., -1:], + obj_inds.to(dtype=torch.long), + avg_factor=cls_avg_factor + ) + if is_sam: + loss_iou = loss_iou * 0 + loss_cls = loss_cls + loss_iou + else: + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=cls_avg_factor + ) + + # loss_mask + num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor])) + num_total_masks = max(num_total_masks, 1) + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + if not self.matching_whole_map: + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + else: + mask_point_targets = mask_targets + mask_point_preds = mask_preds + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points + ) + + return loss_cls, loss_mask, loss_dice + + def forward_logit(self, cls_embd): + cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed) + cls_pred = cls_pred.max(-1).values + cls_pred = self.logit_scale.exp() * cls_pred + return cls_pred + + def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor, + attn_mask_target_size: Tuple[int, int], + num_frames: int = 0) -> Tuple[Tensor]: + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (batch_size, num_queries, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + - num_frames: How many frames are there in video. 
+ """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + # shape (num_queries, batch_size, c) + if isinstance(self.cls_embed, nn.Module): + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + + if not isinstance(self.cls_embed, nn.Module): + maskpool_embd = self.mask_pooling(x=mask_feature, mask=mask_pred.detach()) + maskpool_embd = self.mask_pooling_proj(maskpool_embd) + cls_embd = self.cls_proj(maskpool_embd + decoder_out) + cls_pred = self.forward_logit(cls_embd) + + if self.iou_embed is not None: + iou_pred = self.iou_embed(decoder_out) + cls_pred = torch.cat([cls_pred, iou_pred], dim=-1) + + if num_frames > 0: + assert len(mask_pred.shape) == 4 + assert mask_pred.shape[2] % num_frames == 0 + frame_h = mask_pred.shape[2] // num_frames + num_q = mask_pred.shape[1] + _mask_pred = mask_pred.unflatten(-2, (num_frames, frame_h)).flatten(1, 2) + attn_mask = F.interpolate( + _mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + attn_mask = attn_mask.unflatten(1, (num_q, num_frames)).flatten(2, 3) + else: + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, x: List[Tensor], batch_data_samples: SampleList) -> Tuple[List[Tensor]]: + """Forward function. + + Args: + x (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + tuple[list[Tensor]]: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_img_metas = [] + if isinstance(batch_data_samples[0], TrackDataSample): + for track_sample in batch_data_samples: + cur_list = [] + for det_sample in track_sample: + cur_list.append(det_sample.metainfo) + batch_img_metas.append(cur_list) + num_frames = len(batch_img_metas[0]) + else: + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + num_frames = 0 + batch_size = len(batch_img_metas) + + mask_features, multi_scale_memorys = self.pixel_decoder(x) + if num_frames > 0: + mask_features = mask_features.unflatten(0, (batch_size, num_frames)) + mask_features = mask_features.transpose(1, 2).flatten(2, 3) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + decoder_input = decoder_input.flatten(2).permute(0, 2, 1) + if num_frames > 0: + decoder_input = decoder_input.unflatten(0, (batch_size, num_frames)) + decoder_input = decoder_input.flatten(1, 2) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + + # shape (batch_size, c, h, w) -> (batch_size, h*w, c) + num_frames_real = 1 if num_frames == 0 else num_frames + mask = decoder_input.new_zeros( + (batch_size, num_frames_real) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.transpose( + 1, 2).flatten(2).permute(0, 2, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + + if self.enable_box_query and batch_data_samples[0].data_tag in ['sam_mul', 'sam']: + query_feat, input_query_bbox, self_attn_mask, _ = self.prepare_for_dn_mo(batch_data_samples) + query_embed = coordinate_to_encoding(input_query_bbox.sigmoid()) + query_embed = self.pos_linear(query_embed) + else: + # coco style query generation + # shape (num_queries, c) -> (batch_size, num_queries, c) + query_feat = self.query_feat.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + query_embed = self.query_embed.weight.unsqueeze(0).repeat((batch_size, 1, 1)) + self_attn_mask = None + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:], + num_frames=num_frames + ) + cls_pred_list.append(cls_pred) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. 
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + cross_attn_mask=attn_mask, + self_attn_mask=self_attn_mask, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self._forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:], + num_frames=num_frames + ) + + cls_pred_list.append(cls_pred) + if num_frames > 0: + mask_pred = mask_pred.unflatten(2, (num_frames, -1)) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list, query_feat + + def loss( + self, + x: Tuple[Tensor], + batch_data_samples: SampleList, + ) -> Dict[str, Tensor]: + """Perform forward propagation and loss calculation of the panoptic + head on the features of the upstream network. + + Args: + x (tuple[Tensor]): Multi-level features from the upstream + network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + batch_img_metas = [] + batch_gt_instances = [] + batch_gt_semantic_segs = [] + for data_sample in batch_data_samples: + if isinstance(data_sample, TrackDataSample): + clip_meta = [] + clip_instances = [] + clip_sem_seg = [] + for det_sample in data_sample: + clip_meta.append(det_sample.metainfo) + clip_instances.append(det_sample.gt_instances) + if 'gt_sem_seg' in det_sample: + clip_sem_seg.append(det_sample.gt_sem_seg) + else: + clip_sem_seg.append(None) + batch_img_metas.append(clip_meta) + batch_gt_instances.append(clip_instances) + batch_gt_semantic_segs.append(clip_sem_seg) + else: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'gt_sem_seg' in data_sample: + batch_gt_semantic_segs.append(data_sample.gt_sem_seg) + else: + batch_gt_semantic_segs.append(None) + + # forward + all_cls_scores, all_mask_preds, _ = self(x, batch_data_samples) + + # preprocess ground truth + if not self.enable_box_query or batch_data_samples[0].data_tag in ['coco', 'sam']: + batch_gt_instances = self.preprocess_gt(batch_gt_instances, batch_gt_semantic_segs) + + # loss + if isinstance(batch_data_samples[0], TrackDataSample): + num_frames = len(batch_img_metas[0]) + all_mask_preds = [mask.flatten(2, 3) for mask in all_mask_preds] + for instance in batch_gt_instances: + instance['masks'] = instance['masks'].flatten(1, 2) + film_metas = [ + { + 'img_shape': (meta[0]['img_shape'][0] * num_frames, + meta[0]['img_shape'][1]) + } for meta in batch_img_metas + ] + batch_img_metas = film_metas + + losses = self.loss_by_feat(all_cls_scores, all_mask_preds, batch_gt_instances, batch_img_metas) + + if self.enable_box_query: + losses['loss_zero'] = 0 * self.query_feat.weight.sum() + 0 * self.query_embed.weight.sum() + losses['loss_zero'] += 0 * self.pb_embedding.weight.sum() + losses['loss_zero'] += 0 * self.mask_tokens.weight.sum() + for name, param in self.pos_linear.named_parameters(): + losses['loss_zero'] += 0 * param.sum() + return losses + + def predict(self, x: Tuple[Tensor], + batch_data_samples: SampleList, + return_query=False, + ) -> 
Tuple[Tensor, ...]: + """Test without augmentation. + + Args: + x (tuple[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + batch_data_samples (List[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + return_query (bool): Whether to additionally return the + decoder query features. Defaults to False. + + Returns: + tuple[Tensor]: A tuple containing two tensors. + + - mask_cls_results (Tensor): Mask classification logits,\ + shape (batch_size, num_queries, cls_out_channels). + Note `cls_out_channels` should include background. + - mask_pred_results (Tensor): Mask logits, shape \ + (batch_size, num_queries, h, w). + """ + data_sample = batch_data_samples[0] + if isinstance(data_sample, TrackDataSample): + img_shape = data_sample[0].metainfo['batch_input_shape'] + num_frames = len(data_sample) + else: + img_shape = data_sample.metainfo['batch_input_shape'] + num_frames = 0 + all_cls_scores, all_mask_preds, query_feat = self(x, batch_data_samples) + if self.iou_embed is not None: + _all_cls_scores = [cls_score[..., :-1] for cls_score in all_cls_scores] + iou_results = [cls_score[..., -1:] for cls_score in all_cls_scores] + all_cls_scores = _all_cls_scores + else: + iou_results = None + mask_cls_results = all_cls_scores[-1] + mask_pred_results = all_mask_preds[-1] + if iou_results is not None: + iou_results = iou_results[-1] + + if num_frames > 0: + mask_pred_results = mask_pred_results.flatten(1, 2) + mask_pred_results = F.interpolate( + mask_pred_results, + size=(img_shape[0], img_shape[1]), + mode='bilinear', + align_corners=False) + if num_frames > 0: + num_queries = mask_cls_results.shape[1] + mask_pred_results = mask_pred_results.unflatten(1, (num_queries, num_frames)) + + if iou_results is None: + return mask_cls_results, mask_pred_results + + if return_query: + return mask_cls_results, mask_pred_results, query_feat, iou_results + else: + return mask_cls_results, mask_pred_results, iou_results + + def prepare_for_dn_mo(self, batch_data_samples): + scalar, noise_scale = 100, 0.4 + gt_instances = [t.gt_instances for t in batch_data_samples] + + point_coords = torch.stack([inst.point_coords for inst in gt_instances]) + pb_labels = torch.stack([inst['bp'] for inst in gt_instances]) + labels = torch.zeros_like(pb_labels).long() + + boxes = point_coords # + boxes + + factors = [] + for i, data_sample in enumerate(batch_data_samples): + h, w = data_sample.metainfo['img_shape'] + factor = boxes[i].new_tensor([w, h, w, h]).unsqueeze(0).repeat(boxes[i].size(0), 1) + factors.append(factor) + factors = torch.stack(factors, 0) + + boxes = bbox_xyxy_to_cxcywh(boxes / factors) # TODO: confirm whether xyxy or xywh normalization is intended here
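+ # NOTE: the remainder of this function builds DN-style interactive queries: each click/box is expanded into `num_mask_tokens` mask tokens, light noise is added to the normalized cxcywh coordinates during training, and an attention mask keeps queries of different clicks from attending to each other in the decoder.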
+ # box_start = [t['box_start'] for t in targets] + box_start = [len(point) for point in point_coords] + + known_labels = labels + known_pb_labels = pb_labels + known_bboxs = boxes + + known_labels_expanded = known_labels.clone() + known_pb_labels_expanded = known_pb_labels.clone() + known_bbox_expand = known_bboxs.clone() + + if noise_scale > 0 and self.training: + diff = torch.zeros_like(known_bbox_expand) + diff[:, :, :2] = known_bbox_expand[:, :, 2:] / 2 + diff[:, :, 2:] = known_bbox_expand[:, :, 2:] + # add very small noise to input points; no box + sc = 0.01 + for i, st in enumerate(box_start): + diff[i, :st] = diff[i, :st] * sc + known_bbox_expand += torch.mul( + (torch.rand_like(known_bbox_expand) * 2 - 1.0), + diff) * noise_scale + + known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) + + input_label_embed = self.pb_embedding(known_pb_labels_expanded) + + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + + input_label_embed = input_label_embed.repeat_interleave( + self.num_mask_tokens, + 1) + self.mask_tokens.weight.unsqueeze(0).repeat( + input_label_embed.shape[0], input_label_embed.shape[1], 1) + input_bbox_embed = input_bbox_embed.repeat_interleave( + self.num_mask_tokens, 1) + + single_pad = self.num_mask_tokens + + # NOTE: scalar is recomputed as the number of clicks; queries of different clicks cannot see each other + scalar = int(input_label_embed.shape[1] / self.num_mask_tokens) + + pad_size = input_label_embed.shape[1] + + if input_label_embed.shape[1] > 0: + input_query_label = input_label_embed + input_query_bbox = input_bbox_embed + + tgt_size = pad_size + attn_mask = torch.ones(tgt_size, tgt_size, device=input_label_embed.device) < 0 + # match query cannot see the reconstruct + attn_mask[pad_size:, :pad_size] = True + # reconstruct cannot see each other + for i in range(scalar): + if i == 0: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + if i == scalar - 1: + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + else: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True + mask_dict = { + 'known_lbs_bboxes': (known_labels, known_bboxs), + 'pad_size': pad_size, + 'scalar': scalar, + } + return input_query_label, input_query_bbox, attn_mask, mask_dict diff --git a/seg/models/task_modules/cost.py b/seg/models/task_modules/cost.py new file mode 100644 index 0000000000000000000000000000000000000000..55533d5660b9558b609aba0a6c2cacf6caa9b60e --- /dev/null +++ b/seg/models/task_modules/cost.py @@ -0,0 +1,45 @@ +from typing import Optional, Union + +import torch +from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS + + +@TASK_UTILS.register_module() +class FlexibleClassificationCost(BaseMatchCost): + def __init__(self, weight: Union[float, int] = 1) -> None: + super().__init__(weight=weight) + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + """Compute match cost. + + Args: + pred_instances (:obj:`InstanceData`): ``scores`` inside are the + predicted classification logits with an IoU logit appended + as the last column. + gt_instances (:obj:`InstanceData`): ``labels`` inside should have + shape (num_gt, ). + img_meta (Optional[dict]): Image meta information. Defaults to None. + + Returns: + Tensor: Match cost matrix of shape (num_preds, num_gts).
+ """ + _pred_scores = pred_instances.scores + gt_labels = gt_instances.labels + + pred_scores = _pred_scores[..., :-1] + iou_score = _pred_scores[..., -1:] + + pred_scores = pred_scores.softmax(-1) + iou_score = iou_score.sigmoid() + pred_scores = torch.cat([pred_scores, iou_score], dim=-1) + cls_cost = -pred_scores[:, gt_labels] + + return cls_cost * self.weight diff --git a/seg/models/utils/__init__.py b/seg/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d6c603b76c55d6d35f52be2063b49568ff7ecd --- /dev/null +++ b/seg/models/utils/__init__.py @@ -0,0 +1,7 @@ +from .video_gt_preprocess import preprocess_video_panoptic_gt +from .mask_pool import mask_pool +from .pan_seg_transform import INSTANCE_OFFSET_HB, mmpan2hbpan, mmgt2hbpan +from .class_overlapping import calculate_class_overlapping +from .online_pq_utils import cal_pq, IoUObj, NO_OBJ_ID +from .no_obj import NO_OBJ +from .offline_video_metrics import vpq_eval, stq diff --git a/seg/models/utils/class_overlapping.py b/seg/models/utils/class_overlapping.py new file mode 100644 index 0000000000000000000000000000000000000000..8d9835bbdda71b9ead3987bf9b876a65c58d763a --- /dev/null +++ b/seg/models/utils/class_overlapping.py @@ -0,0 +1,14 @@ +from typing import List + + +def calculate_class_overlapping(classes1: List[str], classes2: List[str]) -> List[bool]: + words1 = [word for item in classes1 for word in item.split(',')] + results = [] + for item in classes2: + flag: bool = False + for word in item.split(','): + if word in words1: + flag = True + break + results.append(flag) + return results diff --git a/seg/models/utils/load_checkpoint.py b/seg/models/utils/load_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..12e165a3f58814d7f9b933a68857c68a09a0971b --- /dev/null +++ b/seg/models/utils/load_checkpoint.py @@ -0,0 +1,38 @@ +from mmengine.runner.checkpoint import CheckpointLoader + + +def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'): + """Load partial pretrained model with specific prefix. + + Args: + prefix (str): The prefix of sub-module. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. + Defaults to None. + logger: logger + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + + checkpoint = CheckpointLoader.load_checkpoint(filename, map_location=map_location, logger=logger) + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if not prefix: + return state_dict + if not prefix.endswith('.'): + prefix += '.' 
+ prefix_len = len(prefix) + + state_dict = { + k[prefix_len:]: v + for k, v in state_dict.items() if k.startswith(prefix) + } + + assert state_dict, f'{prefix} is not in the pretrained model' + return state_dict diff --git a/seg/models/utils/mask_pool.py b/seg/models/utils/mask_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..06a315b50921703b60ba402a87f047a4e4033119 --- /dev/null +++ b/seg/models/utils/mask_pool.py @@ -0,0 +1,27 @@ +import torch +import torch.nn.functional as F + + +# https://github.com/NVlabs/ODISE/blob/e97b06c424c575fec9fc5368dd4b3e050d91abc4/odise/modeling/meta_arch/odise.py#L923 + +def mask_pool(x, mask): + """ + Args: + x: [B, C, H, W] + mask: [B, Q, H, W] + """ + if not x.shape[-2:] == mask.shape[-2:]: + # reshape mask to x + mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False) + with torch.no_grad(): + mask = mask.detach() + mask = (mask > 0).to(mask.dtype) + denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8 + + mask_pooled_x = torch.einsum( + "bchw,bqhw->bqc", + x, + mask / denorm, + ) + return mask_pooled_x + diff --git a/seg/models/utils/no_obj.py b/seg/models/utils/no_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4788486b1589905399634f1f484063fb2eee15 --- /dev/null +++ b/seg/models/utils/no_obj.py @@ -0,0 +1 @@ +NO_OBJ = 65535 diff --git a/seg/models/utils/offline_video_metrics.py b/seg/models/utils/offline_video_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2a32b1095f58fa7463a77699e1ee419f3a4b6d --- /dev/null +++ b/seg/models/utils/offline_video_metrics.py @@ -0,0 +1,114 @@ +import numpy as np + +from seg.models.utils import NO_OBJ, INSTANCE_OFFSET_HB + + +def vpq_eval(element, num_classes=-1, max_ins=INSTANCE_OFFSET_HB, ign_id=NO_OBJ): + assert num_classes != -1 + import six + pred_ids, gt_ids = element + offset = 1e7 # 1e7 > 200 * max_ins + assert offset > num_classes * max_ins + num_cat = num_classes + 1 + + iou_per_class = np.zeros(num_cat, dtype=np.float64) + tp_per_class = np.zeros(num_cat, dtype=np.float64) + fn_per_class = np.zeros(num_cat, dtype=np.float64) + fp_per_class = np.zeros(num_cat, dtype=np.float64) + + def _ids_to_counts(id_array): + ids, counts = np.unique(id_array, return_counts=True) + return dict(six.moves.zip(ids, counts)) + + pred_areas = _ids_to_counts(pred_ids) + gt_areas = _ids_to_counts(gt_ids) + + void_id = ign_id * max_ins + ign_ids = { + gt_id for gt_id in six.iterkeys(gt_areas) + if (gt_id // max_ins) == ign_id + } + + int_ids = gt_ids.astype(np.uint64) * offset + pred_ids.astype(np.uint64) + int_areas = _ids_to_counts(int_ids) + + def prediction_void_overlap(pred_id): + void_int_id = void_id * offset + pred_id + return int_areas.get(void_int_id, 0) + + def prediction_ignored_overlap(pred_id): + total_ignored_overlap = 0 + for _ign_id in ign_ids: + int_id = _ign_id * offset + pred_id + total_ignored_overlap += int_areas.get(int_id, 0) + return total_ignored_overlap + + gt_matched = set() + pred_matched = set() + + for int_id, int_area in six.iteritems(int_areas): + gt_id = int(int_id // offset) + gt_cat = int(gt_id // max_ins) + pred_id = int(int_id % offset) + pred_cat = int(pred_id // max_ins) + if gt_cat != pred_cat: + continue + union = ( + gt_areas[gt_id] + pred_areas[pred_id] - int_area - + prediction_void_overlap(pred_id) + ) + iou = int_area / union + if iou > 0.5: + tp_per_class[gt_cat] += 1 + iou_per_class[gt_cat] += iou + gt_matched.add(gt_id) + pred_matched.add(pred_id) + + for gt_id in 
six.iterkeys(gt_areas): + if gt_id in gt_matched: + continue + cat_id = gt_id // max_ins + if cat_id == ign_id: + continue + fn_per_class[cat_id] += 1 + + for pred_id in six.iterkeys(pred_areas): + if pred_id in pred_matched: + continue + if (prediction_ignored_overlap(pred_id) / pred_areas[pred_id]) > 0.5: + continue + cat = pred_id // max_ins + fp_per_class[cat] += 1 + + return iou_per_class, tp_per_class, fn_per_class, fp_per_class + + +def stq(element, num_classes=19, max_ins=10000, ign_id=NO_OBJ, num_things=8, label_divisor=1e4, ins_divisor=1e7): + y_pred, y_true = element + y_true = y_true.astype(np.int64) + y_pred = y_pred.astype(np.int64) + + # semantic eval + semantic_label = y_true // max_ins + semantic_prediction = y_pred // max_ins + semantic_label = np.where(semantic_label != ign_id, + semantic_label, num_classes) + semantic_prediction = np.where(semantic_prediction != ign_id, + semantic_prediction, num_classes) + semantic_ids = np.reshape(semantic_label, [-1]) * label_divisor + np.reshape(semantic_prediction, [-1]) + + # instance eval + instance_label = y_true % max_ins + label_mask = np.less(semantic_label, num_things) + prediction_mask = np.less(semantic_prediction, num_things) + is_crowd = np.logical_and(instance_label == 0, label_mask) + + label_mask = np.logical_and(label_mask, np.logical_not(is_crowd)) + prediction_mask = np.logical_and(prediction_mask, np.logical_not(is_crowd)) + + seq_preds = y_pred[prediction_mask] + seg_labels = y_true[label_mask] + + non_crowd_intersection = np.logical_and(label_mask, prediction_mask) + intersection_ids = (y_true[non_crowd_intersection] * ins_divisor + y_pred[non_crowd_intersection]) + return semantic_ids, seq_preds, seg_labels, intersection_ids diff --git a/seg/models/utils/online_pq_utils.py b/seg/models/utils/online_pq_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..283dba0c1c41423c09d277fa5574e7763cf9c41f --- /dev/null +++ b/seg/models/utils/online_pq_utils.py @@ -0,0 +1,73 @@ +from seg.models.utils.no_obj import NO_OBJ +from seg.models.utils.pan_seg_transform import INSTANCE_OFFSET_HB +from panopticapi.evaluation import PQStat + +NO_OBJ_ID = NO_OBJ * INSTANCE_OFFSET_HB + + +class IoUObj: + def __init__(self, intersection: int = 0, union: int = 0): + self.intersection = intersection + self.union = union + + def __iadd__(self, other): + self.intersection += other.intersection + self.union += other.union + return self + + def __isub__(self, other): + self.intersection -= other.intersection + self.union -= other.union + return self + + def is_legal(self): + return self.intersection >= 0 and self.union >= 0 + + @property + def iou(self): + return self.intersection / self.union + + +def cal_pq(global_intersection_info, classes): + num_classes = len(classes) + gt_matched = set() + pred_matched = set() + + gt_all = set() + pred_all = set() + + pq_stat = PQStat() + for gt_id, pred_id in global_intersection_info: + gt_cat = gt_id // INSTANCE_OFFSET_HB + pred_cat = pred_id // INSTANCE_OFFSET_HB + assert pred_cat < num_classes + if global_intersection_info[gt_id, pred_id].union == 0: + continue + if gt_cat == NO_OBJ: + continue + gt_all.add(gt_id) + pred_all.add(pred_id) + if gt_cat != pred_cat: + continue + iou = global_intersection_info[gt_id, pred_id].iou + if iou > 0.5: + pq_stat[gt_cat].tp += 1 + pq_stat[gt_cat].iou += iou + gt_matched.add(gt_id) + pred_matched.add(pred_id) + + for gt_id in gt_all: + gt_cat = gt_id // INSTANCE_OFFSET_HB + if gt_id in gt_matched: + continue + pq_stat[gt_cat].fn += 1 + + for
pred_id in pred_all: + pred_cat = pred_id // INSTANCE_OFFSET_HB + if pred_id in pred_matched: + continue + if global_intersection_info[NO_OBJ_ID, pred_id].iou > 0.5: + continue + pq_stat[pred_cat].fp += 1 + + return pq_stat diff --git a/seg/models/utils/pan_seg_transform.py b/seg/models/utils/pan_seg_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..941dcb6e877596d7a54a359fe700234366ce98f0 --- /dev/null +++ b/seg/models/utils/pan_seg_transform.py @@ -0,0 +1,36 @@ +import copy + +import torch +import numpy as np +from mmdet.evaluation import INSTANCE_OFFSET + +INSTANCE_OFFSET_HB = 10000 + + +def mmpan2hbpan(pred_pan_map, num_classes): + pan_seg_map = - np.ones_like(pred_pan_map) + for itm in np.unique(pred_pan_map): + if itm >= INSTANCE_OFFSET: + # cls labels (from segmentation maps) + cls = itm % INSTANCE_OFFSET + # id labels (from tracking maps) + ins = itm // INSTANCE_OFFSET + pan_seg_map[pred_pan_map == itm] = cls * INSTANCE_OFFSET_HB + ins + elif itm == num_classes: + pan_seg_map[pred_pan_map == itm] = num_classes * INSTANCE_OFFSET_HB + else: + pan_seg_map[pred_pan_map == itm] = itm * INSTANCE_OFFSET_HB + assert -1 not in pan_seg_map + return pan_seg_map + + +def mmgt2hbpan(data_samples): + pan_map = copy.deepcopy(data_samples.gt_sem_seg.sem_seg[0]) + pan_map = pan_map * INSTANCE_OFFSET_HB + gt_instances = data_samples.gt_instances + for idx in range(len(gt_instances)): + mask = torch.tensor(gt_instances.masks.masks[idx], dtype=torch.bool) + instance_id = gt_instances.instances_ids[idx].item() + pan_map[mask] = instance_id + + return pan_map diff --git a/seg/models/utils/video_gt_preprocess.py b/seg/models/utils/video_gt_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..83dfa2a433f04592e662e9191a6c06d3c64ad0dc --- /dev/null +++ b/seg/models/utils/video_gt_preprocess.py @@ -0,0 +1,87 @@ +import torch + + +def preprocess_video_panoptic_gt( + gt_labels, + gt_masks, + gt_semantic_seg, + gt_instance_ids, + num_things, + num_stuff, +): + num_classes = num_things + num_stuff + num_frames = len(gt_masks) + mask_size = gt_masks[0].masks.shape[-2:] + + thing_masks_list = [] + for frame_id in range(num_frames): + thing_masks_list.append(gt_masks[frame_id].pad( + mask_size, pad_val=0).to_tensor( + dtype=torch.bool, device=gt_labels.device) + ) + instances = torch.unique(gt_instance_ids[:, 1]) + things_masks = [] + labels = [] + for instance in instances: + pos_ins = torch.nonzero(torch.eq(gt_instance_ids[:, 1], instance), as_tuple=True)[0] # 0 is for redundant tuple + labels_instance = gt_labels[:, 1][pos_ins] + assert torch.allclose(labels_instance, labels_instance[0]) + labels.append(labels_instance[0]) + instance_frame_ids = gt_instance_ids[:, 0][pos_ins].to(dtype=torch.int32).tolist() + instance_masks = [] + for frame_id in range(num_frames): + frame_instance_ids = gt_instance_ids[gt_instance_ids[:, 0] == frame_id, 1] + if frame_id not in instance_frame_ids: + empty_mask = torch.zeros( + mask_size, + dtype=thing_masks_list[frame_id].dtype, device=thing_masks_list[frame_id].device + ) + instance_masks.append(empty_mask) + else: + pos_inner_frame = torch.nonzero(torch.eq(frame_instance_ids, instance), as_tuple=True)[0].item() + frame_mask = thing_masks_list[frame_id][pos_inner_frame] + instance_masks.append(frame_mask) + things_masks.append(torch.stack(instance_masks)) + + if len(instances) == 0: + things_masks = torch.stack(thing_masks_list, dim=1) + labels = torch.empty_like(instances) + else: + things_masks = 
torch.stack(things_masks) + labels = torch.stack(labels) + assert torch.all(torch.less(labels, num_things)) + + if gt_semantic_seg is not None: + things_labels = labels + gt_semantic_seg = gt_semantic_seg.squeeze(1) + + semantic_labels = torch.unique( + gt_semantic_seg, + sorted=False, + return_inverse=False, + return_counts=False) + stuff_masks_list = [] + stuff_labels_list = [] + for label in semantic_labels: + if label < num_things or label >= num_classes: + continue + stuff_mask = gt_semantic_seg == label + stuff_masks_list.append(stuff_mask) + stuff_labels_list.append(label) + + if len(stuff_masks_list) > 0: + stuff_masks = torch.stack(stuff_masks_list, dim=0) + stuff_labels = torch.stack(stuff_labels_list, dim=0) + assert torch.all(torch.ge(stuff_labels, num_things)) and torch.all(torch.less(stuff_labels, num_classes)) + labels = torch.cat([things_labels, stuff_labels], dim=0) + masks = torch.cat([things_masks, stuff_masks], dim=0) + else: + labels = things_labels + masks = things_masks + assert len(labels) == len(masks) + else: + masks = things_masks + + labels = labels.to(dtype=torch.long) + masks = masks.to(dtype=torch.long) + return labels, masks
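As a quick sanity check on the utilities introduced above, the following standalone snippet (not part of the diff) sketches how a few of them could be exercised. The tensor shapes, class lists, and label encodings are illustrative assumptions, and it presumes the repository and its dependencies (torch, numpy, mmdet, panopticapi, six) are importable.

import numpy as np
import torch
from mmdet.evaluation import INSTANCE_OFFSET

from seg.models.utils import (INSTANCE_OFFSET_HB, calculate_class_overlapping,
                              mask_pool, mmpan2hbpan, vpq_eval)

# mask_pool: average per-query features over thresholded mask regions.
feats = torch.randn(2, 256, 32, 32)    # [B, C, H, W] pixel features
masks = torch.randn(2, 100, 128, 128)  # [B, Q, H, W] query mask logits
pooled = mask_pool(feats, masks)       # -> [2, 100, 256], one vector per query

# calculate_class_overlapping: comma-separated entries are treated as synonyms.
flags = calculate_class_overlapping(['person', 'car,automobile'], ['truck', 'automobile'])
assert flags == [False, True]

# mmpan2hbpan: convert an mmdet-style panoptic map (label + id * INSTANCE_OFFSET)
# into the cls * INSTANCE_OFFSET_HB + id encoding used by these metrics.
num_classes = 3
mm_map = np.zeros((4, 4), dtype=np.int64)
mm_map[:2] = 1 + 2 * INSTANCE_OFFSET   # class 1, instance id 2
mm_map[2:] = num_classes               # void / "no object" pixels
hb_map = mmpan2hbpan(mm_map, num_classes=num_classes)
assert hb_map[0, 0] == 1 * INSTANCE_OFFSET_HB + 2

# vpq_eval: per-class IoU / TP / FN / FP for a (pred, gt) pair encoded as
# cls * INSTANCE_OFFSET_HB + instance_id.
gt = np.full((8, 8), 1 * INSTANCE_OFFSET_HB + 1, dtype=np.int64)
pred = gt.copy()
iou, tp, fn, fp = vpq_eval((pred, gt), num_classes=num_classes)
assert tp[1] == 1 and iou[1] == 1.0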