import os
import sys

# `distutils` was removed from the standard library in Python 3.12 (PEP 632).
# It is only referenced by the commented-out bootstrap snippet below, so a
# missing module must not stop this file from importing.
try:
    import distutils.core
except ImportError:
    pass

# One-off dependency bootstrap for the local detectron2 checkout; kept for
# reference. Uncomment to (re)install detectron2's declared requirements.
# os.system('python -m pip install pyyaml==5.3.1')
# dist = distutils.core.run_setup("./detectron2/setup.py")
# temp = ' '.join([f"'{x}'" for x in dist.install_requires])
# cmd = "python -m pip install {0}".format(temp)
# os.system(cmd)

# Put the ./detectron2 checkout (resolved against the current working
# directory) at the front of sys.path so it shadows any installed copy.
sys.path.insert(0, os.path.abspath('./detectron2'))

import detectron2
import cv2
from detectron2.utils.logger import setup_logger

setup_logger()

# from detectron2.modeling import build_model
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data.datasets import register_coco_instances


def get_splash_detector(output_dir: str = "./output/splash/", score_thresh: float = 0.7):
    """Build an inference predictor for the fine-tuned "splash" Mask R-CNN model.

    The config starts from detectron2's model-zoo
    ``COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml``, is set up for
    2 classes, and loads its weights from ``model_final.pth`` inside
    *output_dir*.

    Args:
        output_dir: Directory holding the trained ``model_final.pth``; also
            stored as ``cfg.OUTPUT_DIR``. Relative paths resolve against the
            current working directory. Defaults to ``"./output/splash/"``.
        score_thresh: Minimum detection score kept at test time
            (``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``). Defaults to 0.7.

    Returns:
        A ``DefaultPredictor`` constructed from this config.
    """
    cfg = get_cfg()
    cfg.OUTPUT_DIR = output_dir
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    )

    # Training-side settings. Presumably these mirror the run that produced
    # model_final.pth; they are kept so the predictor's config records them,
    # though they are likely unused for inference — confirm before removing.
    cfg.DATASETS.TRAIN = ("splash_trains",)
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.SOLVER.IMS_PER_BATCH = 2  # images per training batch
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 300
    cfg.SOLVER.STEPS = []  # empty: no learning-rate decay steps
    # RoI-head sampling batch size (the original note says default is 512;
    # 128 was chosen as faster and sufficient for this small dataset).
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128

    # Number of classes; assumed to match the head size of the trained checkpoint.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    # Load the locally trained weights rather than the model-zoo checkpoint.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh

    predictor = DefaultPredictor(cfg)
    return predictor