| import os |
|
|
| from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot |
| from annotator.uniformer.mmseg.core.evaluation import get_palette |
| from annotator.util import annotator_ckpts_path |
|
|
|
|
# Remote location of the UperNet/UniFormer-small segmentation weights; fetched
# on demand by UniformerDetector when no local copy is present.
checkpoint_file: str = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
|
|
|
|
class UniformerDetector:
    """Semantic-segmentation annotator backed by an mmseg UperNet/UniFormer model.

    On construction the ``upernet_global_small.pth`` checkpoint is downloaded
    from ``checkpoint_file`` into ``annotator_ckpts_path`` if it is not already
    there, and the model is built with ``init_segmentor``. Calling the instance
    runs inference on an image and renders the result with the ``'ade'`` palette.
    """

    def __init__(self, device=None):
        """Load the segmentation model, downloading its checkpoint first if needed.

        Args:
            device: Device string forwarded to ``init_segmentor`` (for example
                ``"cuda"``, ``"cuda:1"`` or ``"cpu"``). When ``None`` (the
                default), ``"cuda"`` is used if ``torch.cuda.is_available()``,
                otherwise ``"cpu"``.
        """
        # torch is needed for CUDA detection below and is not imported at
        # module level; import it locally, like torch.hub further down.
        import torch

        modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth")
        # exist_ok makes directory creation idempotent and race-free, replacing
        # a separate existence check followed by makedirs.
        os.makedirs(annotator_ckpts_path, exist_ok=True)

        if not os.path.exists(modelpath):
            from torch.hub import download_url_to_file
            print(f"Downloading upernet_global_small from {checkpoint_file}...")
            download_url_to_file(checkpoint_file, modelpath)

        # The config lives next to the checkpoints directory:
        # <parent of annotator_ckpts_path>/uniformer/exp/upernet_global_small/config.py
        config_file = os.path.join(
            os.path.dirname(annotator_ckpts_path),
            "uniformer",
            "exp",
            "upernet_global_small",
            "config.py",
        )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = init_segmentor(config_file, modelpath, device=device)

    def __call__(self, img):
        """Segment ``img`` and return the rendered segmentation map.

        Args:
            img: Image passed straight to ``inference_segmentor``; presumably
                an HxWx3 numpy array — TODO confirm against callers.

        Returns:
            Whatever ``show_result_pyplot`` returns for the result drawn with
            the ``'ade'`` palette at full opacity (assumed to be a colour-coded
            segmentation image — verify against the mmseg fork in use).
        """
        result = inference_segmentor(self.model, img)
        # opacity=1 draws only the palette colours, without blending in img.
        res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
        return res_img
|
|