Spaces:
Sleeping
Sleeping
File size: 1,109 Bytes
591ba45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
import glob
import os
import warnings
import cv2
import numpy as np
import skimage.io as io
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from .GeoTr import U2NETP, GeoTr
warnings.filterwarnings("ignore")
class GeoTrP(nn.Module):
    """Wrap a ``GeoTr`` network and rescale its output to a 2560 x 2560 grid.

    The wrapped network's output is divided by 288, mapped linearly with
    ``2 * v - 1``, mapped again with ``(v + 1) / 2 * 2560`` and finally
    resampled bilinearly to a spatial size of 2560 x 2560.
    """

    def __init__(self):
        super().__init__()
        # The attribute keeps the name ``GeoTr``; submodule attribute names
        # become the key prefixes of this module's state dict.
        self.GeoTr = GeoTr()

    def forward(self, x):
        """Run ``GeoTr`` on ``x`` and return the rescaled, upsampled result.

        NOTE(review): the constants suggest the network predicts coordinates
        on a 288-unit scale that are re-expressed on a 2560-unit scale --
        confirm against the ``GeoTr`` definition.
        """
        raw = self.GeoTr(x)  # [0]

        # Step 1: v -> 2 * (v / 288) - 1
        unit = (raw / 288) * 2 - 1

        # Step 2: v -> (v + 1) / 2 * 2560
        scaled = (unit + 1) / 2 * 2560

        # Step 3: bilinear resampling to the fixed output resolution.
        return F.interpolate(
            scaled,
            size=(2560, 2560),
            mode="bilinear",
            align_corners=True,
        )
def reload_model(model: nn.Module, path: str = "") -> nn.Module:
    """Load the checkpoint stored at ``path`` into ``model`` and return it.

    Args:
        model: Module whose state dict is updated in place.
        path: Filesystem path of a checkpoint written with ``torch.save``.
            When empty (or otherwise falsy) ``model`` is returned untouched.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the merged
            state dict does not match the model (e.g. unexpected keys or
            mismatched tensor shapes).

    The number of entries found in the checkpoint is printed to stdout.
    """
    if not path:
        return model

    # Deserialize onto the first GPU when one exists; otherwise fall back to
    # the CPU so that loading does not fail on hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # NOTE: torch.load unpickles the file -- only load checkpoints that come
    # from a trusted source.
    pretrained_dict = torch.load(path, map_location=device)
    print(len(pretrained_dict.keys()))

    # Start from the model's current state so that any entry missing from the
    # checkpoint keeps its existing value, then overlay the loaded tensors.
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
|