# RAVE/annotator/normalbae/__init__.py
# Uploaded by ozgurkara in commit eb9a9b4 ("first commit"); file size 2.76 kB.
import os
import types
import torch
import numpy as np
from einops import rearrange
from .models.NNET import NNET
from modules import devices
from annotator.annotator_path import models_path
import torchvision.transforms as transforms
# load model
def load_checkpoint(fpath, model):
    """Load the ``'model'`` state dict stored at *fpath* into *model*.

    Keys saved from a wrapped model (e.g. ``torch.nn.DataParallel``) typically
    carry a leading ``'module.'`` prefix; that leading prefix is stripped so
    the weights match the bare model. Keys without it are used unchanged.

    Args:
        fpath: Path to a checkpoint readable by ``torch.load`` whose top-level
            object is a mapping with a ``'model'`` entry holding a state dict.
        model: The ``torch.nn.Module`` to populate (strict key matching).

    Returns:
        The same *model* instance, with the weights loaded.

    Raises:
        KeyError: if the checkpoint has no ``'model'`` entry.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(fpath, map_location='cpu')['model']
    prefix = 'module.'
    # Remove only a *leading* prefix: str.replace would also mangle keys that
    # merely contain 'module.' further in (e.g. 'submodule.weight' -> 'subweight').
    load_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt.items()
    }
    model.load_state_dict(load_dict)
    return model
class NormalBaeDetector:
    """Predict a normal map image from an input image with the NormalBae model.

    The ``scannet.pt`` weights are downloaded into ``model_dir`` if missing and
    the network is built lazily on the first ``__call__``.
    """

    model_dir = os.path.join(models_path, "normal_bae")

    def __init__(self):
        self.model = None
        self.device = devices.get_device_for("controlnet")
        # Per-channel mean/std normalization applied to inputs already scaled
        # to [0, 1] (these are the usual ImageNet statistics).
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def load_model(self):
        """Fetch ``scannet.pt`` if absent, build ``NNET`` and load its weights.

        The model is switched to eval mode and moved to ``self.device``.
        """
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/scannet.pt"
        modelpath = os.path.join(self.model_dir, "scannet.pt")
        if not os.path.exists(modelpath):
            # Imported lazily so the dependency is only needed when downloading.
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=self.model_dir)
        # NNET takes its configuration as an attribute namespace.
        args = types.SimpleNamespace()
        args.mode = 'client'
        args.architecture = 'BN'
        args.pretrained = 'scannet'
        args.sampling_ratio = 0.4
        args.importance_ratio = 0.7
        model = NNET(args)
        model = load_checkpoint(modelpath, model)
        model.eval()
        self.model = model.to(self.device)

    def unload_model(self):
        """Move the loaded model to CPU; it stays cached and ``__call__`` moves it back."""
        if self.model is not None:
            self.model.cpu()

    def __call__(self, input_image):
        """Return a ``uint8`` normal-map image for *input_image*.

        Args:
            input_image: 3-D ``numpy`` array laid out (H, W, C). Values are
                divided by 255, so presumably uint8-range RGB with C == 3 to
                match the 3-channel normalization — TODO confirm with callers.

        Returns:
            ``numpy.uint8`` array of shape (h, w, 3): the first three output
            channels mapped from [-1, 1] to [0, 255]. The spatial size is
            whatever NNET produces.
        """
        if self.model is None:
            self.load_model()
        self.model.to(self.device)
        assert input_image.ndim == 3
        # torch.from_numpy rejects arrays with negative strides (e.g. a
        # BGR->RGB view such as img[..., ::-1]); copy to a contiguous array
        # first. This is a no-op for arrays that are already contiguous.
        image = np.ascontiguousarray(input_image)
        with torch.no_grad():
            image = torch.from_numpy(image).float().to(self.device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            image = self.norm(image)
            normal = self.model(image)
            # Output structure is defined by NNET: take the last element of its
            # first output and keep channels 0-2 — TODO confirm semantics.
            normal = normal[0][-1][:, :3]
            # The prediction is used as-is; it is not re-normalized to unit length.
            normal = ((normal + 1) * 0.5).clip(0, 1)
            normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
            normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
            return normal_image