fffiloni commited on
Commit
69952eb
1 Parent(s): e88a625

Added custom models option

Browse files
Files changed (1) hide show
  1. model.py +31 -37
model.py CHANGED
@@ -12,6 +12,7 @@ import cv2
12
  import einops
13
  import numpy as np
14
  import torch
 
15
  from pytorch_lightning import seed_everything
16
 
17
  sys.path.append('ControlNet')
@@ -28,19 +29,8 @@ from cldm.model import create_model, load_state_dict
28
  from ldm.models.diffusion.ddim import DDIMSampler
29
  from share import *
30
 
31
- ORIGINAL_MODEL_NAMES = {
32
- 'canny': 'control_sd15_canny.pth',
33
- 'hough': 'control_sd15_mlsd.pth',
34
- 'hed': 'control_sd15_hed.pth',
35
- 'scribble': 'control_sd15_scribble.pth',
36
- 'pose': 'control_sd15_openpose.pth',
37
- 'seg': 'control_sd15_seg.pth',
38
- 'depth': 'control_sd15_depth.pth',
39
- 'normal': 'control_sd15_normal.pth',
40
- }
41
- ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
 
43
- LIGHTWEIGHT_MODEL_NAMES = {
44
  'canny': 'control_canny-fp16.safetensors',
45
  'hough': 'control_mlsd-fp16.safetensors',
46
  'hed': 'control_hed-fp16.safetensors',
@@ -50,36 +40,44 @@ LIGHTWEIGHT_MODEL_NAMES = {
50
  'depth': 'control_depth-fp16.safetensors',
51
  'normal': 'control_normal-fp16.safetensors',
52
  }
53
- LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
54
 
 
 
 
 
 
55
 
56
  class Model:
57
  def __init__(self,
58
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
- model_dir: str = 'models',
60
- use_lightweight: bool = True):
61
  self.device = torch.device(
62
  'cuda:0' if torch.cuda.is_available() else 'cpu')
63
  self.model = create_model(model_config_path).to(self.device)
64
  self.ddim_sampler = DDIMSampler(self.model)
65
  self.task_name = ''
66
-
 
 
67
  self.model_dir = pathlib.Path(model_dir)
68
  self.model_dir.mkdir(exist_ok=True, parents=True)
69
 
70
- self.use_lightweight = use_lightweight
71
- if use_lightweight:
72
- self.model_names = LIGHTWEIGHT_MODEL_NAMES
73
- self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
74
- base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
75
- self.load_base_model(base_model_url)
76
- else:
77
- self.model_names = ORIGINAL_MODEL_NAMES
78
- self.weight_root = ORIGINAL_WEIGHT_ROOT
79
-
80
  self.download_models()
 
 
 
 
 
 
 
 
 
 
 
81
 
 
82
  def download_base_model(self, model_url: str) -> pathlib.Path:
 
83
  model_name = model_url.split('/')[-1]
84
  out_path = self.model_dir / model_name
85
  if not out_path.exists():
@@ -96,27 +94,23 @@ class Model:
96
  if task_name == self.task_name:
97
  return
98
  weight_path = self.get_weight_path(task_name)
99
- if not self.use_lightweight:
100
- self.model.load_state_dict(
101
- load_state_dict(weight_path, location=self.device))
102
- else:
103
- self.model.control_model.load_state_dict(
104
- load_state_dict(weight_path, location=self.device.type))
105
  self.task_name = task_name
106
 
107
  def get_weight_path(self, task_name: str) -> str:
108
  if 'scribble' in task_name:
109
  task_name = 'scribble'
110
- return f'{self.model_dir}/{self.model_names[task_name]}'
111
 
112
  def download_models(self) -> None:
113
  self.model_dir.mkdir(exist_ok=True, parents=True)
114
- for name in self.model_names.values():
115
  out_path = self.model_dir / name
116
  if out_path.exists():
117
  continue
118
- subprocess.run(
119
- shlex.split(f'wget {self.weight_root}{name} -O {out_path}'))
120
 
121
  @torch.inference_mode()
122
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
@@ -763,4 +757,4 @@ class Model:
763
  127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
764
 
765
  results = [x_samples[i] for i in range(num_samples)]
766
- return [detected_map] + results
12
  import einops
13
  import numpy as np
14
  import torch
15
+ from huggingface_hub import hf_hub_url
16
  from pytorch_lightning import seed_everything
17
 
18
  sys.path.append('ControlNet')
29
  from ldm.models.diffusion.ddim import DDIMSampler
30
  from share import *
31
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ MODEL_NAMES = {
34
  'canny': 'control_canny-fp16.safetensors',
35
  'hough': 'control_mlsd-fp16.safetensors',
36
  'hed': 'control_hed-fp16.safetensors',
40
  'depth': 'control_depth-fp16.safetensors',
41
  'normal': 'control_normal-fp16.safetensors',
42
  }
 
43
 
44
+ MODEL_REPO = 'webui/ControlNet-modules-safetensors'
45
+
46
+ DEFAULT_BASE_MODEL_REPO = 'runwayml/stable-diffusion-v1-5'
47
+ DEFAULT_BASE_MODEL_FILENAME = 'v1-5-pruned-emaonly.safetensors'
48
+ DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
49
 
50
  class Model:
51
  def __init__(self,
52
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
53
+ model_dir: str = 'models'):
 
54
  self.device = torch.device(
55
  'cuda:0' if torch.cuda.is_available() else 'cpu')
56
  self.model = create_model(model_config_path).to(self.device)
57
  self.ddim_sampler = DDIMSampler(self.model)
58
  self.task_name = ''
59
+
60
+ self.base_model_url = ''
61
+
62
  self.model_dir = pathlib.Path(model_dir)
63
  self.model_dir.mkdir(exist_ok=True, parents=True)
64
 
 
 
 
 
 
 
 
 
 
 
65
  self.download_models()
66
+ self.set_base_model(DEFAULT_BASE_MODEL_REPO,
67
+ DEFAULT_BASE_MODEL_FILENAME)
68
+
69
+ def set_base_model(self, model_id: str, filename: str) -> str:
70
+ if not model_id or not filename:
71
+ return self.base_model_url
72
+ base_model_url = hf_hub_url(model_id, filename)
73
+ if base_model_url != self.base_model_url:
74
+ self.load_base_model(base_model_url)
75
+ self.base_model_url = base_model_url
76
+ return self.base_model_url
77
 
78
+
79
  def download_base_model(self, model_url: str) -> pathlib.Path:
80
+ self.model_dir.mkdir(exist_ok=True, parents=True)
81
  model_name = model_url.split('/')[-1]
82
  out_path = self.model_dir / model_name
83
  if not out_path.exists():
94
  if task_name == self.task_name:
95
  return
96
  weight_path = self.get_weight_path(task_name)
97
+ self.model.control_model.load_state_dict(
98
+ load_state_dict(weight_path, location=self.device.type))
 
 
 
 
99
  self.task_name = task_name
100
 
101
  def get_weight_path(self, task_name: str) -> str:
102
  if 'scribble' in task_name:
103
  task_name = 'scribble'
104
+ return f'{self.model_dir}/{MODEL_NAMES[task_name]}'
105
 
106
  def download_models(self) -> None:
107
  self.model_dir.mkdir(exist_ok=True, parents=True)
108
+ for name in MODEL_NAMES.values():
109
  out_path = self.model_dir / name
110
  if out_path.exists():
111
  continue
112
+ model_url = hf_hub_url(MODEL_REPO, name)
113
+ subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
114
 
115
  @torch.inference_mode()
116
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
757
  127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
758
 
759
  results = [x_samples[i] for i in range(num_samples)]
760
+ return [detected_map] + results