File size: 2,341 Bytes
3f9c56c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
from annotator.annotator_path import models_path
from modules import devices
from annotator.uniformer.inference import init_segmentor, inference_segmentor, show_result_pyplot

try:
    from mmseg.core.evaluation import get_palette
except ImportError:
    from annotator.mmpkg.mmseg.core.evaluation import get_palette

# Directory containing this module; also the legacy location of the checkpoint.
old_modeldir = os.path.dirname(os.path.realpath(__file__))
# Preferred checkpoint directory, under the shared annotator models path.
modeldir = os.path.join(models_path, "uniformer")
# Remote location the checkpoint is fetched from when no local copy exists.
checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
# Segmentor configuration file that lives next to this module.
config_file = os.path.join(old_modeldir, "upernet_global_small.py")
# Cached segmentor instance; stays None until apply_uniformer builds it.
model = None

def unload_uniformer_model():
    """Move the cached segmentor to the CPU if one has been created.

    Does nothing when the module-level ``model`` is still ``None``. Otherwise
    ``model`` is rebound to whatever ``model.cpu()`` returns.
    """
    global model
    if model is None:
        # Nothing has been loaded yet, so there is nothing to move.
        return
    model = model.cpu()

def apply_uniformer(img):
    """Segment ``img`` with the UniFormer model and return the rendered result.

    The segmentor is built on first use and cached in the module-level
    ``model``. The checkpoint is looked up next to this file first
    (``old_modeldir``), then in ``modeldir``; if neither copy exists it is
    downloaded from ``checkpoint_file`` into ``modeldir``.

    Args:
        img: Image handed unchanged to ``inference_segmentor`` and
            ``show_result_pyplot``.

    Returns:
        The value returned by ``show_result_pyplot`` for the segmentation
        result, drawn with the ``'ade'`` palette at ``opacity=1``.
    """
    global model
    # Resolve the target device once; it is needed to build the model, to move
    # a cached model back onto the device, and for the MPS check below.
    device = devices.get_device_for("controlnet")

    if model is None:
        modelpath = os.path.join(modeldir, "upernet_global_small.pth")
        old_modelpath = os.path.join(old_modeldir, "upernet_global_small.pth")
        if os.path.exists(old_modelpath):
            # A checkpoint sitting beside this module takes precedence.
            modelpath = old_modelpath
        elif not os.path.exists(modelpath):
            # Imported lazily so basicsr is only required when a download happens.
            from basicsr.utils.download_util import load_file_from_url
            # NOTE(review): relies on load_file_from_url saving the file under
            # the URL's basename ("upernet_global_small.pth") — confirm.
            load_file_from_url(checkpoint_file, model_dir=modeldir)

        model = init_segmentor(config_file, modelpath, device=device)
    # unload_uniformer_model() may have moved the cached model to the CPU.
    model = model.to(device)

    if device.type == 'mps':
        # adaptive_avg_pool2d can fail on MPS, workaround with CPU
        import torch.nn.functional

        orig_adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d

        def cpu_if_exception(input, *args, **kwargs):
            try:
                return orig_adaptive_avg_pool2d(input, *args, **kwargs)
            except Exception:
                # Retry on the CPU only for ordinary errors; KeyboardInterrupt
                # and SystemExit must propagate instead of being retried.
                return orig_adaptive_avg_pool2d(input.cpu(), *args, **kwargs).to(input.device)

        try:
            # Patch only for the duration of this inference call.
            torch.nn.functional.adaptive_avg_pool2d = cpu_if_exception
            result = inference_segmentor(model, img)
        finally:
            # Always restore the original function, even if inference fails.
            torch.nn.functional.adaptive_avg_pool2d = orig_adaptive_avg_pool2d
    else:
        result = inference_segmentor(model, img)

    res_img = show_result_pyplot(model, img, result, get_palette('ade'), opacity=1)
    return res_img