import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import glob
import warnings

import cv2
import numpy as np
import skimage.io as io
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from .GeoTr import U2NETP, GeoTr

# NOTE(review): this silences *all* warnings process-wide as a side effect of
# importing this module; consider narrowing the filter.
warnings.filterwarnings("ignore")


class GeoTrP(nn.Module):
    """Wrap ``GeoTr`` and rescale its output map to a larger square resolution.

    The raw ``GeoTr`` output is divided by ``in_size``, multiplied by
    ``out_size``, and then bilinearly resized to ``(out_size, out_size)``.
    Presumably the output is a backward map in pixel coordinates of an
    ``in_size`` x ``in_size`` input -- TODO confirm against ``GeoTr``.

    Args:
        in_size: coordinate scale of the raw ``GeoTr`` output (default 288).
        out_size: target coordinate scale and spatial size (default 2560).
    """

    # Class-level defaults so instances pickled before these became
    # constructor arguments still resolve them.
    in_size = 288
    out_size = 2560

    def __init__(self, in_size=288, out_size=2560):
        super().__init__()
        self.GeoTr = GeoTr()
        self.in_size = in_size
        self.out_size = out_size

    def forward(self, x):
        """Run ``GeoTr`` on ``x`` and return the rescaled, resized map."""
        bm = self.GeoTr(x)
        # Map [0, in_size] -> [-1, 1] -> [0, out_size]. The two steps are kept
        # (rather than a single multiply) to preserve the original float math.
        bm = 2 * (bm / self.in_size) - 1
        bm = (bm + 1) / 2 * self.out_size
        # Bilinear F.interpolate requires a 4-D (N, C, H, W) input tensor.
        bm = F.interpolate(
            bm,
            size=(self.out_size, self.out_size),
            mode="bilinear",
            align_corners=True,
        )
        return bm


def reload_model(model, path="", map_location=None):
    """Load a checkpoint's tensors into ``model`` and return it.

    Args:
        model: the ``nn.Module`` whose state dict is populated.
        path: checkpoint file path; when empty, ``model`` is returned unchanged.
        map_location: forwarded to ``torch.load``. When ``None`` it resolves to
            ``"cuda:0"`` if CUDA is available (the previous hard-coded value)
            and to ``"cpu"`` otherwise, so CPU-only machines can load weights.

    Returns:
        The same ``model`` object, with the checkpoint's tensors loaded.

    Raises:
        RuntimeError: from ``load_state_dict`` (strict mode) when the
            checkpoint contains keys the model does not define.
    """
    if not path:
        return model

    if map_location is None:
        map_location = "cuda:0" if torch.cuda.is_available() else "cpu"

    model_dict = model.state_dict()
    pretrained_dict = torch.load(path, map_location=map_location)
    print(len(pretrained_dict.keys()))

    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    # Strip it only when *all* checkpoint keys carry it and the model's own
    # keys do not, so checkpoints that already matched are left untouched.
    prefix = "module."
    if (
        pretrained_dict
        and all(k.startswith(prefix) for k in pretrained_dict)
        and not any(k.startswith(prefix) for k in model_dict)
    ):
        pretrained_dict = {k[len(prefix):]: v for k, v in pretrained_dict.items()}

    # Report how many checkpoint keys line up with the model's parameters.
    print(sum(1 for k in pretrained_dict if k in model_dict))

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model